import os
from itertools import combinations
from tempfile import NamedTemporaryFile

import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from rdkit import RDLogger
from sklearn.model_selection import train_test_split

# Project helpers. These star imports supply callbacks used below
# (smiles_to_image, generate_canonical, evaluate_and_log, display_plot);
# NOTE(review): which module defines which name is not visible from here.
from molecule_generation_helpers import *
from property_prediction_helpers import *

# Flip to True to show the normally hidden widgets (column pickers, dataset
# previews, model/task selectors, custom-dataset settings).
DEBUG_VISIBLE = False

# Only surface RDKit log messages at ERROR level or above.
RDLogger.logger().setLevel(RDLogger.ERROR)

# Predefined dataset paths (these should be adjusted to your file paths).
# Each value appears to be "<train csv>, <test csv>, <input col>, <output col>";
# only the first field (the train CSV) is read in this file. The " " entry is
# a blank placeholder choice for the dataset dropdown.
predefined_datasets = {
    " ": " ",
    "BACE": "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
    "ESOL": "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
}

# Models offered for property prediction. Every non-empty combination of them
# is evaluated when the user presses Submit on the property prediction tab.
models_enabled = [
    "MorganFingerprint",
    "SMI-TED",
    "SELFIES-TED",
    "MHG-GED",
    "POS-EGNN",
]

# Number of non-empty model combinations, i.e. the number of evaluation steps
# reported on the progress slider.
N_MODEL_COMBINATIONS = 2 ** len(models_enabled) - 1

# Custom uploads of this size or larger are ignored by process_custom_file.
MAX_CUSTOM_FILE_BYTES = 50 * 1024

# Empty results table shown before and after an evaluation run.
blank_df = pd.DataFrame({"id": [], "Model": [], "Score": []})


def load_predefined_dataset(dataset_name):
    """Preview a predefined dataset and reset the column selectors.

    Returns a 4-tuple:
    (preview DataFrame, input-column dropdown update, output-column dropdown
    update, lower-cased dataset name). For an unknown name (e.g. "Custom
    Dataset") the name becomes "custom". If the train CSV cannot be read (as
    for the " " placeholder) an empty preview and empty choices are returned
    with the original name lower-cased.
    """
    val = predefined_datasets.get(dataset_name)
    if val:
        try:
            # The first comma-separated field is the train CSV path.
            df = pd.read_csv(val.split(",")[0])
        except (OSError, ValueError):
            # Missing/unreadable file or unparsable CSV: fall through to the
            # empty result below (best effort, as before).
            pass
        else:
            return (
                df.head(),
                gr.update(choices=list(df.columns), value=None),
                gr.update(choices=list(df.columns), value=None),
                dataset_name.lower(),
            )
    else:
        dataset_name = "Custom"
    return (
        pd.DataFrame(),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
        dataset_name.lower(),
    )


def handle_dataset_selection(selected_dataset, state):
    """Record the chosen dataset in *state* and configure the settings panel.

    Stores the dataset name (or "CUSTOM" for anything not predefined) under
    state["dataset_name"]. Returns (visibility update for the custom-dataset
    settings accordion, task type), where the task type is "Classification"
    for BACE, "Regression" for ESOL and None otherwise.
    """
    state["dataset_name"] = (
        selected_dataset if selected_dataset in predefined_datasets else "CUSTOM"
    )
    # Known task type for the built-in datasets; custom data decides later.
    task_type = (
        "Classification"
        if selected_dataset == "BACE"
        else "Regression"
        if selected_dataset == "ESOL"
        else None
    )
    # The custom-dataset settings are only shown for non-predefined choices.
    return (
        gr.update(visible=selected_dataset not in predefined_datasets or DEBUG_VISIBLE),
        task_type,
    )


def select_columns(input_column, output_column, train_data, test_data, state):
    """Build the "train,test,input,output,dataset" descriptor for a custom run.

    Returns gr.update() (leave the message untouched) unless both files and
    both column names are set.
    """
    if train_data and test_data and input_column and output_column:
        return f"{train_data.name},{test_data.name},{input_column},{output_column},{state['dataset_name']}"
    return gr.update()


def display_csv_head(file):
    """Return (first 5 rows, input-column update, output-column update) for *file*.

    With no file, an empty preview and empty column choices are returned.
    """
    if file is not None:
        df = pd.read_csv(file.name)
        return (
            df.head(),
            gr.update(choices=list(df.columns)),
            gr.update(choices=list(df.columns)),
        )
    return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])


def _write_temp_csv(frame, prefix):
    """Write *frame* (without its index) to a new temporary CSV and return its path.

    The file is created with delete=False, so it outlives this call and can be
    handed to Gradio by path.
    """
    # Write through the open handle instead of re-opening the file by name:
    # re-opening a still-open NamedTemporaryFile by name fails on Windows.
    # newline="" is what pandas expects when writing CSV to a text handle.
    with NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        newline="",
        prefix=prefix,
        suffix=".csv",
        delete=False,
    ) as handle:
        frame.to_csv(handle, index=False)
    return handle.name


def process_custom_file(file, selected_dataset):
    """Split an uploaded single-file custom dataset into train/test CSV files.

    The upload is used only when it is smaller than MAX_CUSTOM_FILE_BYTES and
    has both an "input" and an "output" column. It is then split 80/20 (no
    random_state is passed) into two temporary CSV files.

    Returns:
        On success: (train_path, test_path, "input", "output", task_type),
        where task_type is "Classification" for an int64 "output" column and
        "Regression" otherwise.
        Otherwise: (None, None, None, None, x), where x is gr.update() (keep
        the current task type) for a predefined dataset and None otherwise.
    """
    if file and os.path.getsize(file.name) < MAX_CUSTOM_FILE_BYTES:
        df = pd.read_csv(file.name)
        if "input" in df.columns and "output" in df.columns:
            train, test = train_test_split(df, test_size=0.2)
            train_path = _write_temp_csv(train, "fm4m-train-")
            test_path = _write_temp_csv(test, "fm4m-test-")
            task_type = (
                "Classification" if df["output"].dtype == np.int64 else "Regression"
            )
            return train_path, test_path, "input", "output", task_type
    # Selecting a dataset clears custom_file, which re-triggers this callback;
    # for predefined datasets gr.update() keeps the task type already chosen.
    return (
        None,
        None,
        None,
        None,
        gr.update() if selected_dataset in predefined_datasets else None,
    )


def update_plot_choices(current, state):
    """Offer only the plot types whose data is present in *state*.

    "ROC-AUC" needs state["roc_auc"], "Parity Plot" needs state["RMSE"] and
    "Latent Space" needs state["x_batch"]. The current selection is kept if
    still available; otherwise the first available type (or None) is selected.
    """
    choices = []
    if state.get("roc_auc") is not None:
        choices.append("ROC-AUC")
    if state.get("RMSE") is not None:
        choices.append("Parity Plot")
    if state.get("x_batch") is not None:
        choices.append("Latent Space")
    if current in choices:
        return gr.update(choices=choices)
    return gr.update(choices=choices, value=None if len(choices) == 0 else choices[0])


def log_selected(df: pd.DataFrame, evt: gr.SelectData, state):
    """Merge the stored results of the clicked log-table row into *state*.

    The row's "id" is looked up in state["results"]; unknown ids merge nothing.
    """
    state.update(state["results"].get(df.at[evt.index[0], 'id'], {}))


# Sample molecules offered on the Molecule Generation tab:
# key -> SMILES string and path of a pre-rendered image.
smiles_image_mapping = {
    "Mol 1": {
        "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
        "image": "img/img1.png",
    },
    "Mol 2": {
        "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
        "image": "img/img2.png",
    },
    "Mol 3": {
        "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
        "image": "img/img3.png",
    },
    "Mol 4": {
        "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
        "image": "img/img4.png",
    },
    "Mol 5": {
        "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
        "image": "img/img5.png",
    },
}


def load_image(path):
    """Return the pre-rendered image for sample key *path*, or None.

    None is returned when the key is not in smiles_image_mapping (e.g. the
    selector was cleared) or when the image file cannot be opened/decoded.
    """
    try:
        image_path = smiles_image_mapping[path]["image"]
    except (KeyError, TypeError):
        return None
    try:
        # Copy the pixels so the file handle is closed immediately instead of
        # staying open until the lazily-loaded image is garbage-collected.
        with Image.open(image_path) as img:
            return img.copy()
    except OSError:
        return None


def handle_image_selection(image_key):
    """Return (SMILES, rendered molecule image) for a sample key, or (None, None)."""
    if not image_key:
        return None, None
    smiles = smiles_image_mapping[image_key]["smiles"]
    mol_image = smiles_to_image(smiles)
    return smiles, mol_image


# Introduction
with gr.Blocks() as introduction:
    with open("INTRODUCTION.md", encoding="utf-8") as f:
        gr.Markdown(f.read(), sanitize_html=False)


# Property Prediction
with gr.Blocks() as property_prediction:
    state = gr.State({"model_name": "Default - Auto", "results": {}})
    gr.HTML(
        '''

Task : Property Prediction
Models are finetuned with different combination of modalities on the uploaded or selected built data set.

'''
    )

    with gr.Row():
        with gr.Column():
            # Dropdown menu for predefined datasets plus a "Custom Dataset" option.
            dataset_selector = gr.Dropdown(
                label="Select Dataset",
                choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
            )

            # Hidden carrier for the dataset descriptor passed to evaluation.
            selected_columns_message = gr.Textbox(
                label="Selected Columns Info", visible=DEBUG_VISIBLE
            )

            with gr.Accordion(
                "Custom Dataset Settings", open=True, visible=DEBUG_VISIBLE
            ) as settings:
                # Single-file upload, split into train/test by process_custom_file.
                custom_file = gr.File(
                    label="Upload Custom Dataset",
                    file_types=[".csv"],
                )
                train_file = gr.File(
                    label="Upload Custom Train Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                train_display = gr.Dataframe(
                    label="Train Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                test_file = gr.File(
                    label="Upload Custom Test Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                test_display = gr.Dataframe(
                    label="Test Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )

            # Predefined dataset preview
            predefined_display = gr.Dataframe(
                label="Predefined Dataset Preview (First 5 Rows)",
                interactive=False,
                visible=DEBUG_VISIBLE,
            )

            # Input/output column selectors
            input_column_selector = gr.Dropdown(
                label="Select Input Column",
                choices=[],
                allow_custom_value=True,
                visible=DEBUG_VISIBLE,
            )
            output_column_selector = gr.Dropdown(
                label="Select Output Column",
                choices=[],
                allow_custom_value=True,
                visible=DEBUG_VISIBLE,
            )

            # A new train/test file refreshes its preview and the column choices.
            train_file.change(
                display_csv_head,
                inputs=train_file,
                outputs=[
                    train_display,
                    input_column_selector,
                    output_column_selector,
                ],
            )
            test_file.change(
                display_csv_head,
                inputs=test_file,
                outputs=[
                    test_display,
                    input_column_selector,
                    output_column_selector,
                ],
            )

            model_checkbox = gr.CheckboxGroup(
                choices=models_enabled, label="Select Model", visible=DEBUG_VISIBLE
            )

            task_radiobutton = gr.Radio(
                choices=["Classification", "Regression"],
                label="Task Type",
                visible=DEBUG_VISIBLE,
            )

            # Changing the dataset: clear any custom upload, update state and the
            # settings panel, then preview the predefined dataset (if any).
            dataset_selector.change(lambda: None, outputs=custom_file).then(
                handle_dataset_selection,
                inputs=[dataset_selector, state],
                outputs=[settings, task_radiobutton],
            ).then(
                load_predefined_dataset,
                inputs=dataset_selector,
                outputs=[
                    predefined_display,
                    input_column_selector,
                    output_column_selector,
                    selected_columns_message,
                ],
            )

            custom_file.change(
                process_custom_file,
                inputs=[custom_file, dataset_selector],
                outputs=[
                    train_file,
                    test_file,
                    input_column_selector,
                    output_column_selector,
                    task_radiobutton,
                ],
            )

            eval_clear_button = gr.Button("Clear")
            eval_button = gr.Button("Submit", variant="primary")

            # One step per evaluated model combination.
            step_slider = gr.Slider(
                minimum=0,
                maximum=N_MODEL_COMBINATIONS,
                value=0,
                label="Progress",
                show_label=True,
                interactive=False,
                visible=False,
            )

        # Right Column
        with gr.Column():
            log_table = gr.Dataframe(value=blank_df, interactive=False)
            plot_radio = gr.Radio(choices=[], label="Select Plot Type")
            plot_output = gr.Plot(label="Visualization")

            # Clicking a result row loads its data, refreshes plot choices and
            # redraws the plot.
            log_table.select(log_selected, [log_table, state]).success(
                update_plot_choices, inputs=[plot_radio, state], outputs=plot_radio
            ).then(display_plot, inputs=[plot_radio, state], outputs=plot_output)

    def clear_eval(state):
        """Drop stored results; return (no plot, empty plot choices, blank table)."""
        state["results"] = {}
        return None, gr.update(choices=[], value=None), blank_df

    def eval_part(part, step, selector, show_progress=False):
        """Chain one evaluation step onto the event chain *part*.

        Ticks the models at indices *selector* in model_checkbox, runs
        evaluate_and_log, then moves the progress slider to *step*. Returns the
        new tail of the chain.
        """
        return (
            part.then(
                lambda: [models_enabled[x] for x in selector],
                outputs=model_checkbox,
            )
            .then(
                evaluate_and_log,
                inputs=[
                    model_checkbox,
                    selected_columns_message,
                    task_radiobutton,
                    log_table,
                    state,
                ],
                outputs=log_table,
                show_progress=show_progress,
            )
            .then(lambda: step, outputs=step_slider, show_progress=False)
        )

    # Submit: lock the buttons, build the dataset descriptor, reset results,
    # and show the progress slider at 0.
    part = (
        eval_button.click(
            lambda: (
                gr.update(interactive=False),
                gr.update(interactive=False),
            ),
            outputs=[eval_clear_button, eval_button],
        )
        .then(
            select_columns,
            inputs=[
                input_column_selector,
                output_column_selector,
                train_file,
                test_file,
                state,
            ],
            outputs=selected_columns_message,
        )
        .then(
            clear_eval,
            inputs=state,
            outputs=[
                plot_output,
                plot_radio,
                log_table,
            ],
        )
    )

    part = part.then(
        lambda: gr.update(value=0, visible=True),
        outputs=step_slider,
        show_progress=False,
    )

    # Evaluate every non-empty combination of models, smallest groups first
    # (all singletons, then all pairs, ..., then all models together). Only
    # the first combination shows Gradio's progress indicator.
    model_indices = range(len(models_enabled))
    model_combos = (
        combo
        for group_size in range(1, len(models_enabled) + 1)
        for combo in combinations(model_indices, group_size)
    )
    for step_index, combo in enumerate(model_combos, start=1):
        part = eval_part(part, step_index, list(combo), show_progress=combo == (0,))

    # Hide the progress slider and unlock the buttons when all steps finish.
    part = part.then(
        lambda: gr.update(visible=False),
        outputs=step_slider,
        show_progress=False,
    )
    part.then(
        lambda: (
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[eval_clear_button, eval_button],
    )

    plot_radio.change(
        display_plot, inputs=[plot_radio, state], outputs=plot_output
    )

    # Clear: reset results and the plot, then reset the dataset dropdown.
    eval_clear_button.click(
        clear_eval,
        inputs=state,
        outputs=[
            plot_output,
            plot_radio,
            log_table,
        ],
    ).then(lambda: " ", outputs=dataset_selector)


# Molecule Generation
with gr.Blocks() as molecule_generation:
    gr.HTML(
        '''

Task : Molecule Generation
Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.

'''
    )
    with gr.Row():
        with gr.Column():
            smiles_input = gr.Textbox(label="Input SMILES String")
            image_display = gr.Image(label="Molecule Image", height=250, width=250)
            # Sample molecules to pick from
            with gr.Accordion("Select from sample molecules", open=False):
                image_selector = gr.Radio(
                    choices=list(smiles_image_mapping.keys()),
                    label="Select from sample molecules",
                    value=None,
                )
                image_selector.change(load_image, image_selector, image_display)
            clear_button = gr.Button("Clear")
            generate_button = gr.Button("Submit", variant="primary")

        # Right Column
        with gr.Column():
            gen_image_display = gr.Image(
                label="Generated Molecule Image", height=250, width=250
            )
            generated_output = gr.Textbox(label="Generated Output")
            property_table = gr.Dataframe(label="Molecular Properties Comparison")

    # Picking a sample fills the SMILES box and renders it.
    image_selector.change(
        handle_image_selection,
        inputs=image_selector,
        outputs=[smiles_input, image_display],
    )

    smiles_input.change(
        smiles_to_image, inputs=smiles_input, outputs=image_display
    )

    # Submit: lock buttons, generate, then unlock buttons.
    generate_button.click(
        lambda: (
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        outputs=[clear_button, generate_button],
    ).then(
        generate_canonical,
        inputs=smiles_input,
        outputs=[property_table, generated_output, gen_image_display],
    ).then(
        lambda: (
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[clear_button, generate_button],
    )

    clear_button.click(
        lambda: (None, None, None, None, None, None),
        outputs=[
            smiles_input,
            image_display,
            image_selector,
            gen_image_display,
            generated_output,
            property_table,
        ],
    )

# Render with tabs
gr.TabbedInterface(
    [introduction, property_prediction, molecule_generation],
    ["Introduction", "Property Prediction", "Molecule Generation"],
).launch(server_name="0.0.0.0", allowed_paths=["./"])