Last commit not found
""" | |
This file tests whether the model predictions for MHNfs match the predictions made on
the JKU development server (verified model, server conda env with specific packages, ...).
"""
#--------------------------------------------------------------------------------------- | |
# Dependencies | |
import pytest | |
import torch | |
import pandas as pd | |
from prediction_pipeline import ActivityPredictor | |
#--------------------------------------------------------------------------------------- | |
# Define tests | |
class TestActivityPredictor:
    """Regression tests for ``ActivityPredictor``.

    ``test_mhnfs_prediction`` compares raw model outputs against stored reference
    predictions. ``test_query_mol_return`` checks that query molecules handed to
    ``predict`` in different container types are returned as the same list of
    SMILES strings by ``_return_query_mols_as_list``.
    """

    def test_mhnfs_prediction(self, model_input_query, model_input_support_actives,
                              model_input_support_inactives, model_predictions):
        """Model outputs must match the reference predictions to within 0.01 (absolute).

        All four arguments are pytest fixtures defined outside this file
        (presumably in ``conftest.py`` -- confirm).
        """
        # Load model
        predictor = ActivityPredictor()

        # Additional model inputs: the support set sizes, read from dim 1 of the
        # support tensors (assumes a (batch, set_size, ...) layout -- TODO confirm).
        support_actives_size = torch.tensor(model_input_support_actives.shape[1])
        support_inactives_size = torch.tensor(model_input_support_inactives.shape[1])

        # Pure inference: no autograd graph is needed for the comparison, so
        # don't build one (the original detached after the fact).
        with torch.no_grad():
            predictions = predictor.model(
                model_input_query,
                model_input_support_actives,
                model_input_support_inactives,
                support_actives_size,
                support_inactives_size,
            )

        # Compare predictions with an absolute tolerance only (rtol disabled).
        # The message is evaluated only on failure and reports the worst deviation.
        assert torch.allclose(predictions, model_predictions, atol=0.01, rtol=0.0), (
            "MHNfs predictions deviate from the reference; max abs diff = "
            f"{(predictions - model_predictions).abs().max().item()}"
        )

    def test_query_mol_return(self):
        """Query molecules must round-trip as a list for every accepted input type.

        A list, a comma-separated string and a DataFrame with a ``"smiles"``
        column must all be returned as the same list of SMILES strings. An unset
        query (``None``) must raise ``ValueError``; any other type must raise
        ``TypeError``.
        """
        # Support set: one active and one inactive molecule
        support_actives_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1"]
        support_inactives_smiles = ["CCN(CC)C(=S)SSC(=S)N(CC)CCCCC"]

        # Expected output of _return_query_mols_as_list for every input form
        query_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1",
                        "CCN(CC)C(=S)SSC(=S)N(CC)CC"]

        # Load activity predictor
        predictor = ActivityPredictor()

        # Checks 1-3: the same query molecules given as a list, as a
        # comma-separated string, and as a pd.DataFrame with a "smiles" column.
        query_inputs = {
            "list": query_smiles,
            "comma-separated string": ",".join(query_smiles),
            "pd.DataFrame": pd.DataFrame({"smiles": query_smiles}),
        }
        for input_kind, query in query_inputs.items():
            _ = predictor.predict(query, support_actives_smiles,
                                  support_inactives_smiles)
            query_output = predictor._return_query_mols_as_list()
            assert query_output == query_smiles, (
                f"query given as {input_kind} returned {query_output!r}"
            )

        # Check 4: Query molecules storage is None
        predictor.query_molecules = None
        with pytest.raises(ValueError):
            predictor._return_query_mols_as_list()

        # Check 5: Other data types
        predictor.query_molecules = 123  # any other data type
        with pytest.raises(TypeError):
            predictor._return_query_mols_as_list()