translation / app.py
import os
from threading import Thread, Event
from typing import Iterator
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
DESCRIPTION = """\
# Monlam LLM v2.0.1 - Translation
## This version first generates detailed reasoning (thoughts); the final translation follows after the marker #Final Translation.
"""
# Constants
path = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft"
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Load the model and tokenizer
tokenizer = GemmaTokenizerFast.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
model.config.sliding_window = 4096
model.eval()
model.config.use_cache = True
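# Note: the model above is loaded in float16 and moved to "cuda", so a GPU is required.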
# Shared stop event
stop_event = Event()
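# NOTE: this Event is shared by all requests, so pressing "Stop" cancels every
# generation that is currently streaming, not just the caller's.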
# Generate function
def generate(
    message: str,
    show_thoughts: bool,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = False,
) -> Iterator[str]:
    stop_event.clear()
    message = message.replace("\n", " ")
    # Prepare input for the model
    conversation = [
        {"role": "user", "content": f"Please translate the following into English: {message} Translation:"}
    ]
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    # Use a streamer to get generated text as it is produced
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
    )
    if do_sample:
        # Sampling parameters only apply when sampling is enabled
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)
    # Generate in a separate thread so the streamer can be consumed here
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    in_translation = False
    for text in streamer:
        if stop_event.is_set():
            break
        # Switch to translation mode once the marker appears in the stream
        if "#Final Translation:" in text and not in_translation:
            in_translation = True
            if not show_thoughts:
                # Drop the reasoning when "View Detailed Thoughts" is disabled
                text = text.split("#Final Translation:", 1)[1].strip()
        if in_translation:
            outputs.append(text)
            yield "".join(outputs)
        elif show_thoughts:
            outputs.append(text)
            yield "".join(outputs)
    # Full assistant response (thoughts and/or translation)
    chat_history = "".join(outputs)
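# Illustrative stream shape (not real model output): the raw generation looks like
#   "<reasoning about the source text> #Final Translation: <English text>"
# so with show_thoughts=False only the part after the marker is yielded.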
# Stop generation function
def stop_generation():
    stop_event.set()
# Create the Gradio interface
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_text = gr.Textbox(label="Enter Tibetan text", placeholder="Type Tibetan text here...")
        show_thoughts = gr.Checkbox(label="View Detailed Thoughts", value=True)
    submit_button = gr.Button("Translate")
    stop_button = gr.Button("Stop")
    with gr.Row():
        output_area = gr.Textbox(
            label="Output (Thoughts and Translation)",
            lines=20,
            interactive=False,
        )
    # Connect buttons to functions
    submit_button.click(
        fn=generate,
        inputs=[input_text, show_thoughts],
        outputs=output_area,
        queue=True,  # Enable streaming
    )
    stop_button.click(stop_generation)
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)
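# Sketch: consuming the generator directly, outside Gradio (assumes the model above
# is already loaded; the Tibetan string is just a sample greeting).
#
#     final = ""
#     for partial in generate("བཀྲ་ཤིས་བདེ་ལེགས།", show_thoughts=False):
#         final = partial  # each yield is the cumulative text so far
#     print(final)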