alimotahharynia committed
Commit dc3d3fa
1 Parent(s): ca57160

Add Gradio app

Files changed (1)
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
import os
import json
import torch
import logging
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, GPT2LMHeadModel

# Global logging setup
def setup_logging(output_file="app.log"):
    log_filename = os.path.splitext(output_file)[0] + ".log"
    logging.getLogger().handlers.clear()

    file_handler = logging.FileHandler(log_filename)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)

# Load model and tokenizer
def load_model_and_tokenizer(model_name):
    logging.info(f"Loading model and tokenizer: {model_name}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        if torch.cuda.is_available():
            logging.info("Moving model to CUDA device.")
            model = model.to("cuda")
        return model, tokenizer
    except Exception as e:
        logging.error(f"Error loading model and tokenizer: {e}")
        raise RuntimeError(f"Failed to load model and tokenizer: {e}")

# Load the dataset
def load_uniprot_dataset(dataset_name, dataset_key):
    try:
        dataset = load_dataset(dataset_name, dataset_key)
        uniprot_to_sequence = {row["UniProt_id"]: row["Sequence"] for row in dataset["uniprot_seq"]}
        logging.info("Dataset loaded and processed successfully.")
        return uniprot_to_sequence
    except Exception as e:
        logging.error(f"Error loading dataset: {e}")
        raise RuntimeError(f"Failed to load dataset: {e}")

# SMILES Generator
class SMILESGenerator:
    def __init__(self, model, tokenizer, uniprot_to_sequence):
        self.model = model
        self.tokenizer = tokenizer
        self.uniprot_to_sequence = uniprot_to_sequence
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.generation_kwargs = {
            "do_sample": True,
            "top_k": 9,
            "max_length": 1024,
            "top_p": 0.9,
            "num_return_sequences": 10,
            "bos_token_id": tokenizer.bos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
        }

    def generate_smiles(self, sequence, num_generated, progress_callback=None):
        generated_smiles_set = set()
        prompt = f"<|startoftext|><P>{sequence}<L>"
        encoded_prompt = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)

        logging.info(f"Generating SMILES for sequence: {sequence[:10]}...")
        retries = 0
        while len(generated_smiles_set) < num_generated:
            if retries >= 30:
                logging.warning("Max retries reached. Returning what has been generated so far.")
                break

            sample_outputs = self.model.generate(encoded_prompt, **self.generation_kwargs)
            for sample_output in sample_outputs:
                output_decode = self.tokenizer.decode(sample_output, skip_special_tokens=False)
                try:
                    # The ligand SMILES sits between the <L> tag and the end-of-text token.
                    generated_smiles = output_decode.split("<L>")[1].split("<|endoftext|>")[0]
                    generated_smiles_set.add(generated_smiles)
                except (IndexError, AttributeError) as e:
                    logging.warning(f"Failed to parse SMILES due to error: {str(e)}. Skipping.")

            if progress_callback:
                progress_callback((retries + 1) / 30)

            retries += 1

        logging.info(f"SMILES generation completed. Generated {len(generated_smiles_set)} SMILES.")
        # Each batch returns up to 10 samples, so cap the result at the requested count.
        return list(generated_smiles_set)[:num_generated]

# Gradio interface
def generate_smiles_gradio(sequence_input=None, uniprot_id=None, num_generated=10):
    results = {}

    # Process sequence inputs and include the UniProt ID if one is found
    if sequence_input:
        sequences = [seq.strip() for seq in sequence_input.split(",") if seq.strip()]
        for seq in sequences:
            try:
                # Reverse-look up the UniProt ID for the sequence. Note that all
                # unmatched sequences share the "N/A" key and overwrite each other.
                matches = [uid for uid, s in uniprot_to_sequence.items() if s == seq]
                uniprot_id_for_seq = matches[0] if matches else "N/A"

                # Generate SMILES for the sequence
                smiles = generator.generate_smiles(seq, num_generated)
                results[uniprot_id_for_seq] = {
                    "sequence": seq,
                    "smiles": smiles
                }
            except Exception as e:
                results["N/A"] = {"sequence": seq, "error": f"Error generating SMILES: {str(e)}"}

    # Process UniProt ID inputs and include the sequence if found
    if uniprot_id:
        uniprot_ids = [uid.strip() for uid in uniprot_id.split(",") if uid.strip()]
        for uid in uniprot_ids:
            sequence = uniprot_to_sequence.get(uid, "N/A")
            try:
                # Generate SMILES for the sequence found
                if sequence != "N/A":
                    smiles = generator.generate_smiles(sequence, num_generated)
                    results[uid] = {
                        "sequence": sequence,
                        "smiles": smiles
                    }
                else:
                    results[uid] = {
                        "sequence": "N/A",
                        "error": f"UniProt ID {uid} not found in the dataset."
                    }
            except Exception as e:
                results[uid] = {"sequence": "N/A", "error": f"Error generating SMILES: {str(e)}"}

    # The click handler binds two outputs, so always return a (results, file) pair.
    if not results:
        return {"error": "No SMILES generated. Please try again with different inputs."}, None

    # Save results to a file
    file_path = save_smiles_to_file(results)
    return results, file_path

def save_smiles_to_file(results):
    file_path = os.path.join(tempfile.gettempdir(), "generated_smiles.json")
    with open(file_path, "w") as f:
        json.dump(results, f, indent=4)
    return file_path

# Main initialization and Gradio setup
if __name__ == "__main__":
    setup_logging()

    model_name = "alimotahharynia/DrugGen"
    dataset_name = "alimotahharynia/approved_drug_target"
    dataset_key = "uniprot_sequence"

    # Load model, tokenizer, and dataset
    model, tokenizer = load_model_and_tokenizer(model_name)
    uniprot_to_sequence = load_uniprot_dataset(dataset_name, dataset_key)

    # SMILESGenerator
    generator = SMILESGenerator(model, tokenizer, uniprot_to_sequence)

    # Gradio interface
    with gr.Blocks() as iface:
        gr.Markdown("## DrugGen interface")
        with gr.Row():
            sequence_input = gr.Textbox(
                label="Input Protein Sequences",
                placeholder="Enter protein sequences separated by commas..."
            )
            uniprot_id_input = gr.Textbox(
                label="UniProt IDs",
                placeholder="Enter UniProt IDs separated by commas..."
            )
        num_generated_slider = gr.Slider(
            minimum=1, maximum=100, step=1, value=10,
            label="Number of Unique SMILES to Generate"
        )
        output = gr.JSON(label="Generated SMILES")
        file_output = gr.File(label="Download output as .json")

        generate_button = gr.Button("Generate SMILES")
        generate_button.click(
            generate_smiles_gradio,
            inputs=[sequence_input, uniprot_id_input, num_generated_slider],
            outputs=[output, file_output]
        )

    iface.launch()
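
Because the heavy initialization sits under the if __name__ == "__main__": guard, the module-level helpers can be imported for a quick sanity check without launching the UI. The following is a minimal sketch, not part of the commit: it assumes the file is saved as app.py on the Python path, and the example ID P01308 (human insulin) is an assumed entry that must exist in the dataset.

    # Hypothetical smoke test for the generator, bypassing Gradio.
    from app import SMILESGenerator, load_model_and_tokenizer, load_uniprot_dataset

    model, tokenizer = load_model_and_tokenizer("alimotahharynia/DrugGen")
    uniprot_to_sequence = load_uniprot_dataset(
        "alimotahharynia/approved_drug_target", "uniprot_sequence"
    )
    generator = SMILESGenerator(model, tokenizer, uniprot_to_sequence)

    # P01308 is an assumed example UniProt ID; substitute any ID in the dataset.
    sequence = uniprot_to_sequence.get("P01308")
    if sequence is not None:
        for smiles in generator.generate_smiles(sequence, num_generated=5):
            print(smiles)

Generation uses sampling (do_sample=True), so repeated runs will produce different SMILES sets.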