|
"""Folding Studio Demo App.""" |
|
|
|
import logging |
|
|
|
import gradio as gr |
|
from folding_studio_data_models import FoldingModel |
|
from gradio_molecule3d import Molecule3D |
|
|
|
from folding_studio_demo.predict import predict |
|
|
|
logger = logging.getLogger(name=__name__)

# Representation list handed to every Molecule3D viewer in this app: model 0
# drawn as a cartoon with the "alphafold" colour scheme.
# NOTE(review): empty "chain"/"resname" and around=0 presumably mean "no
# filter" — confirm against the gradio_molecule3d `reps` documentation.
MOLECULE_REPS = [
    dict(
        model=0,
        chain="",
        resname="",
        style="cartoon",
        color="alphafold",
        around=0,
        byres=False,
    )
]

# Example FASTA record (header line + sequence line) used to pre-fill the
# sequence textbox.
DEFAULT_PROTEIN_SEQ = "\n".join((">protein description", "MALWMRLLPLLALLALWGPDPAAA"))
|
|
|
# Dropdown choices as (display label, FoldingModel value) pairs: the label is
# what the user sees, the FoldingModel value is what event handlers receive.
MODEL_CHOICES = list(
    (
        ("Boltz-1", FoldingModel.BOLTZ),
        ("Chai-1", FoldingModel.CHAI),
        ("Protenix", FoldingModel.PROTENIX),
    )
)
|
|
|
|
|
def sequence_input() -> gr.Textbox:
    """Create a protein sequence textbox paired with a FASTA file uploader.

    Uploading a file replaces the textbox content with the file's text, so
    callers only need to wire the returned textbox into their event handlers.

    Returns:
        gr.Textbox: The sequence textbox, pre-filled with ``DEFAULT_PROTEIN_SEQ``.
    """
    sequence = gr.Textbox(
        label="Protein Sequence",
        value=DEFAULT_PROTEIN_SEQ,
        lines=2,
        placeholder="Enter a protein sequence or upload a FASTA file",
    )
    file_input = gr.File(
        label="Upload a FASTA file",
        file_types=[".fasta", ".fa"],
    )

    def _process_file(file: str | None) -> gr.Textbox:
        """Load the uploaded FASTA file into the sequence textbox.

        Args:
            file: Value emitted by ``gr.File`` — the uploaded file's path
                (depending on the Gradio version, a path string or an object
                exposing it as ``.name``), or ``None`` when the upload is
                cleared.

        Returns:
            gr.Textbox: An update carrying the file content, or an empty
            (no-op) update when there is no file or it cannot be read.
        """
        if file is None:
            return gr.Textbox()
        # Accept both a plain path string and an object exposing ``.name``.
        path = getattr(file, "name", file)
        try:
            # FASTA is plain text; pin the encoding instead of relying on the
            # platform locale (e.g. cp1252 on Windows).
            with open(path, "r", encoding="utf-8") as f:
                content = f.read().strip()
        except Exception:
            # Best-effort UI handler: keep the current textbox content, but
            # log the full traceback so failures are diagnosable.
            logger.exception("Error reading file %s", path)
            return gr.Textbox()
        return gr.Textbox(value=content)

    file_input.change(fn=_process_file, inputs=[file_input], outputs=[sequence])
    return sequence
|
|
|
|
|
def simple_prediction(api_key: gr.Textbox) -> None:
    """Build the single-model prediction tab.

    Lays out a model dropdown, a sequence input, a "Predict" button and the
    structure/pLDDT outputs, then wires the button to ``predict``. Components
    are created in whatever Gradio Blocks/Tab context is active at call time.

    Args:
        api_key (gr.Textbox): Component holding the Folding Studio API key.
            It is passed as an event *input*, so ``predict`` receives its
            current text value when the button is clicked.
    """
    gr.Markdown(
        """
        ### Predict a Protein Structure

        It will be run in the background and the results will be displayed in the output section.
        The output will contain the protein structure and the pLDDT plot.

        Select a model to run the inference with and enter a protein sequence or upload a FASTA file.
        """
    )
    with gr.Row():
        dropdown = gr.Dropdown(
            label="Model",
            choices=MODEL_CHOICES,
            scale=0,
            value=FoldingModel.BOLTZ,
        )
        with gr.Column():
            sequence = sequence_input()

    predict_btn = gr.Button("Predict")

    with gr.Row():
        mol_output = Molecule3D(label="Protein Structure", reps=MOLECULE_REPS)
        metrics_plot = gr.Plot(label="pLDDT")

    # `predict` receives (sequence text, API key, selected model) and its
    # return values fill (structure viewer, pLDDT plot) positionally.
    predict_btn.click(
        fn=predict,
        inputs=[sequence, api_key, dropdown],
        outputs=[mol_output, metrics_plot],
    )
|
|
|
|
|
def model_comparison(api_key: gr.Textbox) -> None:
    """Build the model-comparison tab.

    Lays out a multiselect model dropdown (every entry of ``MODEL_CHOICES``
    selected by default), a sequence input, a "Compare Models" button and the
    structure/pLDDT outputs, then wires the button to ``predict``.

    Args:
        api_key (gr.Textbox): Component holding the Folding Studio API key.
            It is passed as an event *input*, so ``predict`` receives its
            current text value when the button is clicked.
    """
    with gr.Row():
        model = gr.Dropdown(
            label="Model",
            choices=MODEL_CHOICES,
            multiselect=True,
            scale=0,
            min_width=300,
            # Pre-select every available model; derived from MODEL_CHOICES so
            # the default selection cannot drift from the offered choices.
            value=[choice_value for _, choice_value in MODEL_CHOICES],
        )
        with gr.Column():
            sequence = sequence_input()

    predict_btn = gr.Button("Compare Models")

    with gr.Row():
        mol_output = Molecule3D(label="Protein Structure", reps=MOLECULE_REPS)
        metrics_plot = gr.Plot(label="pLDDT")

    # NOTE(review): with multiselect=True the dropdown passes a *list* of
    # models (possibly empty) to `predict` — confirm `predict` handles that.
    predict_btn.click(
        fn=predict,
        inputs=[sequence, api_key, model],
        outputs=[mol_output, metrics_plot],
    )
|
|
|
|
|
def __main__():
    """Build the Folding Studio demo page and launch the Gradio app.

    The page shows an introduction, a password textbox for the API key, and
    two tabs (single-model prediction and model comparison) that both take
    the key from that shared textbox.
    """
    with gr.Blocks(title="Folding Studio Demo") as demo:
        gr.Markdown(
            """
            # Folding Studio: Harness the Power of Protein Folding 🧬

            Folding Studio is a platform for protein structure prediction.
            It uses the latest AI-powered folding models to predict the structure of a protein.

            Available models are: AlphaFold2, OpenFold, SoloSeq, Boltz-1, Chai and Protenix.

            ## API Key
            To use the Folding Studio API, you need to provide an API key.
            You can get your API key by asking the Folding Studio team.
            """
        )
        # One key textbox shared by both tabs; each tab feeds it to `predict`
        # as an event input.
        api_key = gr.Textbox(label="Folding Studio API Key", type="password")
        gr.Markdown("## Demo Usage")
        with gr.Tab("🚀 Simple Prediction"):
            simple_prediction(api_key)
        with gr.Tab("📊 Model Comparison"):
            model_comparison(api_key)

    demo.launch()
|
|