Esm2Text / app.py
habdine's picture
Update app.py
c5d0e06 verified
raw
history blame
3.39 kB
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Markdown rendered at the top of the page.
DESCRIPTION = """\
# ESM2Text Demo
A demo that generates a description of a protein's function from its amino acid sequence using [ESM2Text Base v1.1](https://huggingface.co/habdine/Esm2Text-Base-v1-1). To test the model, enter only the protein's amino acid sequence below, without any spaces.
"""
# Upper bound of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 256
# Initial value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 100
# Use the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Hub checkpoint for both tokenizer and model. trust_remote_code=True lets the
# checkpoint's own Python code run; the `generate_protein_description` method
# used by the app is not a standard transformers method, so it presumably
# comes from that remote code.
_CHECKPOINT = 'habdine/Esm2Text-Base-v1-1'
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT, trust_remote_code=True)
model = model.to(device)
# Inference only: switch off training-mode layers such as dropout.
model.eval()
@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a generated description for the protein sequence in *message*.

    Generation runs in a background thread through the model's
    ``generate_protein_description`` method. Decoded text arrives through a
    ``TextIteratorStreamer``, and after every chunk the text accumulated so
    far is yielded, which is how ``gr.ChatInterface`` streams a reply.

    Args:
        message: Amino acid sequence entered by the user. All whitespace
            (spaces, line breaks) is removed before generation.
        chat_history: Previous turns supplied by ``gr.ChatInterface``. Unused:
            each request is handled independently.
        max_new_tokens: Maximum number of tokens to generate.
            NOTE(review): this default (1024) exceeds ``MAX_MAX_NEW_TOKENS``
            (256); the UI slider always passes its own value.
        do_sample: Whether to sample instead of greedy decoding.
        temperature: Sampling temperature, forwarded to the model.
        top_p: Nucleus-sampling threshold, forwarded to the model.
        top_k: Top-k sampling cutoff, forwarded to the model.
        repetition_penalty: Repetition penalty, forwarded to the model.

    Yields:
        The cumulative generated text.

    Raises:
        Exception: Any error raised by the generation call is re-raised here,
            after the stream has been closed.
    """
    # The UI asks for a sequence "without any spaces"; tolerate pasted
    # sequences that still contain spaces or line breaks.
    protein_sequence = "".join(message.split())

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        protein_sequence=protein_sequence,
        tokenizer=tokenizer,
        device=device,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    errors: list[Exception] = []

    def _run_generation() -> None:
        # An exception raised inside a Thread target is otherwise lost, and
        # the loop below would block until the 20 s streamer timeout and fail
        # with an unrelated queue.Empty. Record the error and end the stream
        # so the consumer stops right away.
        try:
            model.generate_protein_description(**generate_kwargs)
        except Exception as err:
            errors.append(err)
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    worker.join()
    if errors:
        # Surface the real failure in the request's own thread.
        raise errors[0]
# Extra controls shown under the chat box. gr.ChatInterface passes their
# values to `generate` after (message, chat_history), in list order, so the
# order must follow generate's parameter order: max_new_tokens, do_sample,
# temperature, top_p, top_k, repetition_penalty.
_generation_controls = [
    gr.Slider(
        label="Max new tokens",
        minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS,
    ),
    gr.Checkbox(label="Do Sample"),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
]

# Clickable sample sequences; each row supplies only the message text.
_example_sequences = [
    ['AEQAERYEEMVEFMEKL'],
    ["MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA"],
]

chat_interface = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs=_generation_controls,
    examples=_example_sequences,
    cache_examples=False,
    stop_btn=None,
)
# Page layout: the description, a duplicate-Space button, then the chat UI,
# styled with the rules in style.css.
demo = gr.Blocks(css_paths="style.css", fill_height=True)
with demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        elem_id="duplicate-button",
        value="Duplicate Space for private use",
    )
    chat_interface.render()
if __name__ == "__main__":
    # Enable request queuing, with at most 20 events waiting, then serve.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()