import os
import json
import torch
import logging
import tempfile
import gradio as gr
from gradio.themes import Soft
from datasets import load_dataset
from transformers import AutoTokenizer, GPT2LMHeadModel
# Global logging setup
def setup_logging(output_file="app.log"):
log_filename = os.path.splitext(output_file)[0] + ".log"
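    # Clear any existing handlers on the root logger to avoid duplicate log output.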
logging.getLogger().handlers.clear()
file_handler = logging.FileHandler(log_filename)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
# Load model and tokenizer
def load_model_and_tokenizer(model_name):
logging.info(f"Loading model and tokenizer: {model_name}")
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
if torch.cuda.is_available():
logging.info("Moving model to CUDA device.")
model = model.to("cuda")
return model, tokenizer
except Exception as e:
logging.error(f"Error loading model and tokenizer: {e}")
raise RuntimeError(f"Failed to load model and tokenizer: {e}")
# Load the dataset
def load_uniprot_dataset(dataset_name, dataset_key):
try:
dataset = load_dataset(dataset_name, dataset_key)
uniprot_to_sequence = {row["UniProt_id"]: row["Sequence"] for row in dataset["uniprot_seq"]}
logging.info("Dataset loaded and processed successfully.")
return uniprot_to_sequence
except Exception as e:
logging.error(f"Error loading dataset: {e}")
raise RuntimeError(f"Failed to load dataset: {e}")
def save_smiles_to_file(results):
file_path = os.path.join(tempfile.gettempdir(), "generated_smiles.json")
with open(file_path, "w") as f:
json.dump(results, f, indent=4)
return file_path
# SMILES Generator
class SMILESGenerator:
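    """Generates SMILES strings from protein sequences using the loaded DrugGen model."""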
def __init__(self, model, tokenizer, uniprot_to_sequence):
self.model = model
self.tokenizer = tokenizer
self.uniprot_to_sequence = uniprot_to_sequence
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
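        # Sampling settings: top-k/top-p (nucleus) sampling, 10 candidate sequences per generate() call.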
self.generation_kwargs = {
"do_sample": True,
"top_k": 9,
"max_length": 1024,
"top_p": 0.9,
"num_return_sequences": 10,
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": tokenizer.pad_token_id
}
def generate_smiles(self, sequence, num_generated, progress_callback=None):
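        """Sample unique SMILES strings for `sequence`.

        Generation is retried in batches (up to 30) until `num_generated` unique
        SMILES have been collected or the retry limit is reached.
        """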
generated_smiles_set = set()
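        # Prompt format: protein sequence after <P>; the model continues with the ligand SMILES after <L>.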
prompt = f"<|startoftext|><P>{sequence}<L>"
encoded_prompt = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)
logging.info(f"Generating SMILES for sequence: {sequence[:10]}...")
retries = 0
while len(generated_smiles_set) < num_generated:
if retries >= 30:
logging.warning("Max retries reached. Returning what has been generated so far.")
break
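            # Each batch yields num_return_sequences candidates; the set filters out duplicates.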
sample_outputs = self.model.generate(encoded_prompt, **self.generation_kwargs)
            for sample_output in sample_outputs:
output_decode = self.tokenizer.decode(sample_output, skip_special_tokens=False)
try:
generated_smiles = output_decode.split("<L>")[1].split("<|endoftext|>")[0]
if generated_smiles not in generated_smiles_set:
generated_smiles_set.add(generated_smiles)
except (IndexError, AttributeError) as e:
logging.warning(f"Failed to parse small molecule due to error: {str(e)}. Skipping.")
if progress_callback:
progress_callback((retries + 1) / 30)
retries += 1
        logging.info(f"Small molecule generation completed. Generated {len(generated_smiles_set)} unique small molecules.")
        return list(generated_smiles_set)[:num_generated]
# Gradio interface
def generate_smiles_gradio(sequence_input=None, uniprot_id=None, num_generated=10):
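    """Generate SMILES for comma-separated protein sequences and/or UniProt IDs.

    Returns a dictionary of results keyed by UniProt ID (or a numeric key for
    sequences without a known ID) and the path of a downloadable JSON file.
    """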
results = {}
uniprot_counter = 0 # Counter for sequences without UniProt IDs
# Process sequence inputs and include UniProt ID if found
if sequence_input:
sequences = [seq.strip() for seq in sequence_input.split(",") if seq.strip()]
        for seq in sequences:
            # Find the corresponding UniProt ID for the sequence; sequences
            # without a match are keyed by a running number instead.
            matching_ids = [uid for uid, s in uniprot_to_sequence.items() if s == seq]
            if matching_ids:
                uniprot_id_for_seq = matching_ids[0]
            else:
                uniprot_id_for_seq = str(uniprot_counter)
                uniprot_counter += 1
            try:
                # Generate SMILES for the sequence
                smiles = generator.generate_smiles(seq, num_generated)
                results[uniprot_id_for_seq] = {
                    "sequence": seq,
                    "smiles": smiles
                }
            except Exception as e:
                results[uniprot_id_for_seq] = {
                    "sequence": seq,
                    "error": f"Error generating small molecules: {str(e)}"
                }
# Process UniProt ID inputs and include sequence if found
if uniprot_id:
uniprot_ids = [uid.strip() for uid in uniprot_id.split(",") if uid.strip()]
for uid in uniprot_ids:
sequence = uniprot_to_sequence.get(uid, "N/A")
try:
# Generate SMILES for the sequence found
if sequence != "N/A":
smiles = generator.generate_smiles(sequence, num_generated)
results[uid] = {
"sequence": sequence,
"smiles": smiles
}
else:
results[uid] = {
"sequence": "N/A",
"error": f"UniProt ID {uid} not found in the dataset."
}
            except Exception as e:
                results[uid] = {
                    "sequence": sequence,
                    "error": f"Error generating small molecules: {str(e)}"
                }
# Check if no results were generated
    if not results:
        return {"error": "No small molecules generated. Please try again with different inputs."}, None
# Save results to a file
file_path = save_smiles_to_file(results)
# Return both results (JSON) and the file path
return results, file_path
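# Illustrative usage once the globals below are initialized (model, tokenizer,
# generator, uniprot_to_sequence); "P12821" is the example UniProt ID from the
# input placeholder:
#   results, file_path = generate_smiles_gradio(uniprot_id="P12821", num_generated=5)
#   print(results["P12821"]["smiles"])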
# Main initialization and Gradio setup
if __name__ == "__main__":
setup_logging()
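    # Model and dataset identifiers on the Hugging Face Hub.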
model_name = "alimotahharynia/DrugGen"
dataset_name = "alimotahharynia/approved_drug_target"
dataset_key = "uniprot_sequence"
model, tokenizer = load_model_and_tokenizer(model_name)
uniprot_to_sequence = load_uniprot_dataset(dataset_name, dataset_key)
generator = SMILESGenerator(model, tokenizer, uniprot_to_sequence)
# Apply Gradio theme
theme = Soft(primary_hue="indigo", secondary_hue="teal")
css = """
#app-title {
font-size: 24px;
text-align: center;
margin-bottom: 20px;
}
#description {
font-size: 18px;
text-align: center;
margin-bottom: 20px;
}
#file-output {
height: 90px;
}
#generate-button {
height: 40px;
color: #333333 !important;
}
"""
with gr.Blocks(theme=theme, css=css) as iface:
gr.Markdown("## DrugGen", elem_id="app-title")
        gr.Markdown(
            "Generate **drug-like small molecule structures** from protein sequences or UniProt IDs. "
            "Input data, specify parameters, and download the results.",
            elem_id="description"
        )
with gr.Row():
sequence_input = gr.Textbox(
label="Protein Sequences",
placeholder="Enter sequences separated by commas (e.g., MGAASGRRGP..., MGETLGDSPI..., ...)",
lines=3,
)
uniprot_id_input = gr.Textbox(
label="UniProt IDs",
placeholder="Enter UniProt IDs separated by commas (e.g., P12821, P37231, ...)",
lines=3,
)
num_generated_slider = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=10,
label="Number of Unique Small Molecules to Generate",
)
output = gr.JSON(label="Generated Small Molecules")
file_output = gr.File(
label="Download Results as JSON",
            elem_id="file-output"
)
        generate_button = gr.Button("Generate Small Molecules", elem_id="generate-button")
generate_button.click(
generate_smiles_gradio,
inputs=[sequence_input, uniprot_id_input, num_generated_slider],
outputs=[output, file_output]
)
gr.Markdown("""
### How to Cite:
If you use this tool in your research, please cite the following work:
```bibtex
@misc{sheikholeslami2024druggenadvancingdrugdiscovery,
title={DrugGen: Advancing Drug Discovery with Large Language Models and Reinforcement Learning Feedback},
author={Mahsa Sheikholeslami and Navid Mazrouei and Yousof Gheisari and Afshin Fasihi and Matin Irajpour and Ali Motahharynia},
year={2024},
eprint={2411.14157},
archivePrefix={arXiv},
primaryClass={q-bio.QM},
url={https://arxiv.org/abs/2411.14157},
}
```
This will help us maintain the tool and support future development!
""")
    iface.launch(allowed_paths=[tempfile.gettempdir()])