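"""Gradio demo for molecular property prediction with ChemFM model adapters.

The app loads every model in the ChemFM molecular-property-prediction
collection on the Hugging Face Hub, lets the user pick a property, and then
predicts either a single SMILES string or every row of an uploaded CSV file
that contains a "smiles" column.
"""
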
import gradio as gr
from huggingface_hub import HfApi, get_collection, list_collections
from utils import MolecularPropertyPredictionModel, task_types, dataset_descriptions
import pandas as pd
import os

def get_models():
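    """Map each property/dataset name to its model repo id in the ChemFM collection."""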
    # this is the collection id for the molecular property prediction models
    collection = get_collection("ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c")
    models = dict()
    for item in collection.items:
        if item.item_type == "model":
            item_name = item.item_id.split("/")[-1]
            models[item_name] = item.item_id
            assert item_name in task_types, f"{item_name} is not in the task_types"
            assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"
    
    return models

candidate_models = get_models()
properties = list(candidate_models.keys())
model = MolecularPropertyPredictionModel()

def get_description(property_name):
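    """Return the stored description text for the selected property."""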
    return dataset_descriptions[property_name]

def predict_single_label(smiles, property_name):
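    """Predict one SMILES string for the selected property.

    Switches the model to the adapter registered for the property and returns a
    (prediction, status message) pair; "NA" is returned when the adapter cannot
    be loaded or the SMILES string is invalid.
    """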
    adapter_id = candidate_models[property_name]
    info = model.swith_adapter(property_name, adapter_id)

    running_status = None
    if info == "keep":
        running_status = "Adapter is the same as the current one"
    elif info == "switched":
        running_status = "Adapter switched successfully"
    elif info == "error":
        running_status = "Adapter not found"
        return "NA", running_status
    else:
        running_status = "Unknown error"
        return "NA", running_status
    
    prediction = model.predict_single_smiles(smiles, task_types[property_name])
    if prediction is None:
        return "NA", "Invalid SMILES string"
    
    # if the prediction is a float, round it to 3 decimal places
    if isinstance(prediction, float):
        prediction = round(prediction, 3)

    return prediction, "Prediction is done"

def predict_file(file, property_name):
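    """Run batch prediction over an uploaded CSV file.

    Loads the adapter for the selected property, predicts every entry of the
    "smiles" column, saves the result next to the upload with a
    "_prediction.csv" suffix, and returns the component updates that reveal the
    download button.
    """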
    adapter_id = candidate_models[property_name]
    info = model.swith_adapter(property_name, adapter_id)

    running_status = None
    if info == "keep":
        running_status = "Adapter is the same as the current one"
        #print("Adapter is the same as the current one")
    elif info == "switched":
        running_status = "Adapter is switched successfully"
        #print("Adapter is switched successfully")
    elif info == "error":
        running_status = "Adapter is not found"
        #print("Adapter is not found")
        return None, None, file, running_status
    else:
        running_status = "Unknown error"
        return None, None, file, running_status
    
    df = pd.read_csv(file)
    # validate_file has already confirmed that the file contains a "smiles" column
    df = model.predict_file(df, task_types[property_name])
    # write the predictions to disk with a "_prediction" suffix so they can be downloaded
    prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
    df.to_csv(prediction_file, index=False)
    
    return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), prediction_file, "Prediction is done"

def validate_file(file):
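    """Validate an uploaded file: it must be a .csv with a 'smiles' column and at most 100 rows.

    Returns a status message, the (possibly cleared) file value, and the
    visibility updates for the predict and download buttons.
    """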
    try:
        if file.endswith(".csv"):
            df = pd.read_csv(file)
            if "smiles" not in df.columns:
                # we should clear the file input
                return "Invalid file content. The csv file must contain column named 'smiles'", \
                         None, gr.update(visible=False), gr.update(visible=False)
            
            # check the length of the smiles
            length = len(df["smiles"])

        elif file.endswith(".smi"):
            # .smi uploads are accepted by the file picker but are not processed yet
            return "The .smi format is not supported; please upload a .csv file with a 'smiles' column", \
                    None, gr.update(visible=False), gr.update(visible=False)

        else:
            return "Invalid file extension; only .csv files are supported", \
                    None, gr.update(visible=False), gr.update(visible=False)
    except Exception as e:
        return "Invalid file content.", \
                None, gr.update(visible=False), gr.update(visible=False)
    
    if length > 100:
        return "This Space does not support files containing more than 100 SMILES", \
                None, gr.update(visible=False), gr.update(visible=False)

    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
    

def raise_error(status):
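    """Surface a failed validation status to the UI as a gr.Error popup."""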
    if status != "Valid file":
        raise gr.Error(status)
    return None


def clear_file(download_button):
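    """Delete the prediction file and the original upload when the file input is cleared."""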
    # remove the prediction file and the originally uploaded file from disk
    prediction_path = download_button
    if prediction_path and os.path.exists(prediction_path):
        os.remove(prediction_path)
        original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv")
        original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi")
        if os.path.exists(original_data_file_0):
            os.remove(original_data_file_0)
        if os.path.exists(original_data_file_1):
            os.remove(original_data_file_1)
    

    return gr.update(visible=False), gr.update(visible=False), None

def build_inference():
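    """Assemble the Gradio Blocks interface and wire up all event handlers."""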

    with gr.Blocks() as demo:
        # property selection dropdown and its description box
        dropdown = gr.Dropdown(properties, label="Property", value=properties[0])
        description_box = gr.Textbox(label="Property description", lines=5,
                                     interactive=False,
                                     value=dataset_descriptions[properties[0]])
        # single-molecule input: SMILES textbox and prediction label
        with gr.Row(equal_height=True):
            with gr.Column():
                textbox = gr.Textbox(label="Molecule SMILES", type="text", placeholder="Provide a SMILES string here",
                                     lines=1)
                predict_single_smiles_button = gr.Button("Predict", size='sm')
            prediction = gr.Label("Prediction will appear here")

        running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)
        
        input_file = gr.File(label="Molecule file",
                       file_count='single',
                       file_types=[".smi", ".csv"], height=300)
        predict_file_button = gr.Button("Predict", size='sm', visible=False)
        download_button = gr.DownloadButton("Download", size='sm', visible=False)

        # dropdown change event
        dropdown.change(get_description, inputs=dropdown, outputs=description_box)
        # predict single button click event
        predict_single_smiles_button.click(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])
        # input file upload event
        file_status = gr.State()
        input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
        # input file clear event
        input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
        # predict file button click event
        predict_file_button.click(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, input_file, running_terminal_label])
        
    return demo


demo = build_inference() 

if __name__ == '__main__':
    demo.launch()