diff --git a/nomad_dos_fingerprints/DOSfingerprint.py b/nomad_dos_fingerprints/DOSfingerprint.py index b0a76278e52eb9fd949a240f92c4d5354fa1e7a0..77e59b03a4b24d05a38db2171fccad06c4e53305 100644 --- a/nomad_dos_fingerprints/DOSfingerprint.py +++ b/nomad_dos_fingerprints/DOSfingerprint.py @@ -1,18 +1,21 @@ import numpy as np from bitarray import bitarray +from functools import partial from .grid import Grid +from .similarity import tanimoto_similarity ELECTRON_CHARGE = 1.602176565e-19 class DOSFingerprint(): - def __init__(self, stepsize = 0.05): + def __init__(self, stepsize = 0.05, similarity_function = tanimoto_similarity, **kwargs): self.bins = '' self.indices = [] self.stepsize = stepsize self.filling_factor = 0 self.grid_id = None + self.set_similarity_function(similarity_function, **kwargs) def calculate(self, dos_energies, dos_values, grid_id = 'dg_cut:56:-2:7:(-10, 5)', unit_cell_volume = 1, n_atoms = 1): energy, dos = self._convert_dos(dos_energies, dos_values, unit_cell_volume = unit_cell_volume, n_atoms = n_atoms) @@ -35,6 +38,15 @@ class DOSFingerprint(): self.filling_factor = fp_dict['filling_factor'] return self + def set_similarity_function(self, similarity_function, **kwargs): + self.similarity_function = partial(similarity_function, **kwargs) + + def get_similarity(self, fingerprint): + return self.similarity_function(self, fingerprint) + + def get_similarities(self, list_of_fingerprints): + return np.array([self.similarity_function(self, fp) for fp in list_of_fingerprints]) + def _integrate_to_bins(self, xs, ys): """ Performs stepwise numerical integration of ``ys`` over the range of ``xs``. The stepsize of the generated histogram is controlled by DOSFingerprint().stepsize. @@ -57,7 +69,7 @@ class DOSFingerprint(): dos_channels = [np.array(values) for values in dos] dos = sum(dos_channels) * ELECTRON_CHARGE * unit_cell_volume * n_atoms return energy, dos - + def _binary_bin(self, dos_value, grid_bins): bin_dos = '' for grid_bin in grid_bins: diff --git a/tests/test_DOSfingerprint.py b/tests/test_DOSfingerprint.py index c5579a68bcb7e2b7af22271a2edf1f897dfa627f..bc6ed940413f662da02db8bf7d75cfba3407c945 100644 --- a/tests/test_DOSfingerprint.py +++ b/tests/test_DOSfingerprint.py @@ -2,7 +2,7 @@ import pytest import numpy as np from nomad_dos_fingerprints import DOSFingerprint, tanimoto_similarity -from nomad_dos_fingerprints.DOSfingerprint import ELECTRON_CHARGE +from nomad_dos_fingerprints.DOSfingerprint import ELECTRON_CHARGE def test_integrate_to_bins(): test_data_x = np.linspace(0, np.pi, num = 1000) diff --git a/tests/test_functional.py b/tests/test_functional.py index bc3e91c221adb022dce30443793316107e71da32..fb0fd7b492f5355600b4cd84ac81dd0aa524f4ae 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -4,9 +4,9 @@ import numpy as np with open(os.path.join(os.path.dirname(__file__), 'fingerprint_generation_test_data.json'), 'r') as test_data_file: test_data = json.load(test_data_file) - + def test_fingerprint_values(): - + for fp, mid in test_data['fingerprints']: raw_data = test_data[mid] new_fingerprint = DOSFingerprint().calculate(raw_data['dos_energies'], raw_data['dos_values']) @@ -15,7 +15,7 @@ def test_fingerprint_values(): old_fingerprint.indices = json.loads(fp)['indices'] old_fingerprint.grid_id = new_fingerprint.grid_id assert old_fingerprint.indices == new_fingerprint.indices - assert np.isclose(tanimoto_similarity(old_fingerprint, new_fingerprint),1, atol=5e-2) + assert np.isclose(old_fingerprint.get_similarity(new_fingerprint),1, atol=5e-2) def test_materials_similarity(): @@ -25,10 +25,7 @@ def test_materials_similarity(): raw_data = [test_data[mid] for mid in mids] new_fingerprints = [DOSFingerprint().calculate(entry['dos_energies'], entry['dos_values']) for entry in raw_data] matrix = [] - for fp1 in new_fingerprints: - row = [] - for fp2 in new_fingerprints: - row.append(tanimoto_similarity(fp1,fp2)) - matrix.append(row) + for fp in new_fingerprints: + matrix.append(fp.get_similarities(new_fingerprints)) print(matrix - np.array(similarity_matrix)) - assert np.isclose(similarity_matrix, matrix, atol = 5e-2).all() \ No newline at end of file + assert np.isclose(similarity_matrix, matrix, atol = 5e-2).all()