File size: 2,328 Bytes
78960a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# prediction.py

import pandas as pd
from utils import extract_genus, load_rules
import re

# Rule table, loaded once at import time via utils.load_rules().
# predict_susceptibility iterates its keys as (genus_pattern, antibiotic) pairs
# (key[0] is used as a regex against the genus, key[1] is compared to the
# antibiotic) and compares the values with decoded model labels.
# NOTE(review): presumably values are susceptibility labels in the same
# vocabulary as the "susceptibility" encoder — confirm against utils.load_rules.
rules = load_rules()

def encode_input(input_data, encoders):
    """Translate categorical fields of *input_data* into their encoded codes.

    Args:
        input_data: Mapping of column name to raw value.
        encoders: Mapping of column name to a label-to-code lookup. Columns
            with no entry here are copied through unchanged.

    Returns:
        A new dict with the same keys as *input_data* holding the encoded
        values. If a value is missing from its column's lookup, returns
        instead the string
        ``"Error: '<value>' is not a valid option for '<col>'"``.
    """
    result = {}
    for column, raw in input_data.items():
        if column not in encoders:
            # No encoder registered: treat as an already-numeric feature.
            result[column] = raw
            continue
        try:
            result[column] = encoders[column][raw]
        except KeyError:
            # Report the bad value to the caller instead of raising.
            return f"Error: '{raw}' is not a valid option for '{column}'"
    return result

def decode_output(encoded_value, encoder_name, encoders):
    """Map an encoded value back to its original label.

    Args:
        encoded_value: Code to decode (e.g. a model prediction).
        encoder_name: Key in *encoders* whose label-to-code mapping is inverted.
        encoders: Mapping of name to a label-to-code dict.

    Returns:
        The label whose code equals *encoded_value*, or ``"Unknown"`` when
        either the encoder or the code is not found. If several labels share
        a code, the one appearing last in the mapping wins.
    """
    if encoder_name not in encoders:
        return "Unknown"
    mapping = encoders[encoder_name]
    label_by_code = {code: label for label, code in mapping.items()}
    return label_by_code.get(encoded_value, "Unknown")

def predict_susceptibility(input_data, model, encoders):
    """Predict susceptibility with the model, letting a matching rule override it.

    Args:
        input_data: Mapping of feature name to raw value. Must contain
            ``"organism"`` and ``"antibiotic"``; all entries are encoded with
            ``encode_input`` and passed to the model.
        model: Fitted estimator exposing ``predict``. If it also exposes
            ``feature_names_in_`` (scikit-learn estimators fitted on a
            DataFrame do), the encoded columns are selected and reordered to
            match it.
        encoders: Per-column label-to-code mappings; the ``"susceptibility"``
            entry is used to decode the model's prediction.

    Returns:
        On success, a dict with keys ``"Final Output"``, ``"Rule Guidance"``,
        ``"Model Prediction"`` and ``"Decision Reason"``. If ``encode_input``
        rejects a value, its error *string* is returned as-is. Any other
        exception is reported as ``{"Error": str(e)}`` rather than raised.
    """
    no_rule = "No specific rule found"
    try:
        organism = input_data["organism"]
        antibiotic = input_data["antibiotic"]
        genus = extract_genus(organism)

        # ``rules`` is keyed by (genus_pattern, antibiotic); the pattern is a
        # regex anchored at the start of the genus. Compare the antibiotic
        # first: it is cheap and skips evaluating (and possibly failing on)
        # the patterns of rules that belong to other antibiotics.
        matching_rule = next(
            (
                rules[key]
                for key in rules
                if key[1] == antibiotic and re.match(fr"^{key[0]}", genus)
            ),
            None,
        )
        rule_output = matching_rule if matching_rule else no_rule

        # Model prediction
        encoded_data = encode_input(input_data, encoders)
        if isinstance(encoded_data, str):
            # encode_input reports an invalid categorical value as a string.
            return encoded_data

        encoded_df = pd.DataFrame([encoded_data])
        # Column order follows input_data's key order, which is up to the
        # caller. scikit-learn validates column names/order against those seen
        # at fit time (raising on mismatch in recent releases; older releases
        # only warned and read columns positionally), so align explicitly.
        feature_names = getattr(model, "feature_names_in_", None)
        if feature_names is not None:
            encoded_df = encoded_df[list(feature_names)]
        model_prediction = model.predict(encoded_df)[0]
        model_output = decode_output(model_prediction, "susceptibility", encoders)

        # A matching rule wins whenever it disagrees with the model.
        if rule_output != no_rule and rule_output != model_output:
            final_output = rule_output
            reason = "Rule-based correction applied."
        else:
            final_output = model_output
            reason = "Model prediction used."

        return {
            "Final Output": final_output,
            "Rule Guidance": rule_output,
            "Model Prediction": model_output,
            "Decision Reason": reason,
        }
    except Exception as e:
        # Boundary handler: callers receive the message instead of a traceback.
        return {"Error": str(e)}