EAV123's picture
Upload 16 files
78960a3 verified
# prediction.py
import pandas as pd
from utils import extract_genus, load_rules
import re
# Rule table loaded once, at module import time, via utils.load_rules().
# predict_susceptibility iterates it as a mapping and indexes each key as
# key[0] (a regex matched against the start of the genus) and key[1]
# (compared for equality with the antibiotic); the mapped value is the rule's
# output. NOTE(review): confirm this key schema against utils.load_rules.
rules = load_rules()
def encode_input(input_data, encoders):
    """Label-encode every feature of *input_data* that has an encoder.

    Args:
        input_data: Mapping of feature name to raw value.
        encoders: Mapping of feature name to a lookup from raw value to its
            encoded value. Features with no entry here (e.g. numeric ones)
            are copied through unchanged.

    Returns:
        A new dict of feature name to encoded (or passed-through) value. If an
        encoder has no entry for a value, processing stops and the string
        "Error: '<value>' is not a valid option for '<feature>'" is returned
        instead of a dict.
    """
    result = {}
    for name, raw in input_data.items():
        if name not in encoders:
            # No encoder for this feature: keep the raw value as-is.
            result[name] = raw
            continue
        try:
            result[name] = encoders[name][raw]
        except KeyError:
            return f"Error: '{raw}' is not a valid option for '{name}'"
    return result
def decode_output(encoded_value, encoder_name, encoders):
    """Map an encoded value back to its original label.

    Args:
        encoded_value: The code to translate back.
        encoder_name: Key of the encoder in *encoders* to invert.
        encoders: Mapping of encoder name to a lookup from label to code.

    Returns:
        The label whose code equals *encoded_value*, or "Unknown" when the
        encoder is missing or the code is not present. If several labels share
        a code, the one appearing last in the encoder wins.
    """
    if encoder_name not in encoders:
        return "Unknown"
    inverse = dict((code, label) for label, code in encoders[encoder_name].items())
    return inverse.get(encoded_value, "Unknown")
def _first_matching_rule(genus, antibiotic):
    """Return the value of the first rule matching *genus* and *antibiotic*.

    Each key of the module-level ``rules`` mapping is indexed as key[0], a
    regex matched at the start of *genus*, and key[1], compared with ``==`` to
    *antibiotic*. Keys are tried in the mapping's iteration order. Returns
    None when no key matches.
    """
    for key in rules:
        if re.match(fr"^{key[0]}", genus) and key[1] == antibiotic:
            return rules[key]
    return None


def predict_susceptibility(input_data, model, encoders):
    """Predict susceptibility by combining the trained model with expert rules.

    A rule is looked up for the organism's genus and the antibiotic, then the
    model is run on the label-encoded input. When a rule exists and disagrees
    with the model's decoded prediction, the rule's value is used instead.

    Args:
        input_data: Mapping of feature name to value; must contain the keys
            "organism" and "antibiotic".
        model: Object with a ``predict`` method accepting a one-row DataFrame
            (presumably a scikit-learn style estimator -- confirm).
        encoders: Per-feature lookups from raw value to encoded value, as
            consumed by encode_input / decode_output.

    Returns:
        On success, a dict with "Final Output", "Rule Guidance",
        "Model Prediction" and "Decision Reason". If encode_input rejects a
        value, its error string is returned as-is. Any exception raised along
        the way is reported as {"Error": <message>}.
    """
    no_rule = "No specific rule found"
    try:
        organism = input_data["organism"]
        antibiotic = input_data["antibiotic"]
        genus = extract_genus(organism)

        # Expert rule lookup: regex on the genus, exact match on the antibiotic.
        rule = _first_matching_rule(genus, antibiotic)
        rule_output = rule if rule else no_rule

        # Model prediction on the label-encoded features.
        encoded = encode_input(input_data, encoders)
        if isinstance(encoded, str):
            return encoded  # validation message produced by encode_input
        frame = pd.DataFrame([encoded])
        predicted_code = model.predict(frame)[0]
        model_output = decode_output(predicted_code, "susceptibility", encoders)

        # A matching rule that disagrees with the model takes precedence.
        rule_overrides = rule_output != no_rule and rule_output != model_output
        if rule_overrides:
            final_output, reason = rule_output, "Rule-based correction applied."
        else:
            final_output, reason = model_output, "Model prediction used."

        return {
            "Final Output": final_output,
            "Rule Guidance": rule_output,
            "Model Prediction": model_output,
            "Decision Reason": reason,
        }
    except Exception as e:
        # Boundary handler: every failure is surfaced to the caller as a dict.
        return {"Error": str(e)}