import os
from threading import Event, Thread
from typing import Iterator

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    GemmaTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

DESCRIPTION = """\
# Monlam LLM v2.0.1 - Translation

This version first generates detailed reasoning (thoughts) and then, after the marker
`#Final Translation`, produces the translation.
"""

# Constants
path = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft"
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Load the model and tokenizer
tokenizer = GemmaTokenizerFast.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
model.config.sliding_window = 4096
model.config.use_cache = True
model.eval()

# Shared stop event, set by the Stop button and checked during generation
stop_event = Event()


class StopOnEvent(StoppingCriteria):
    """Abort generation inside model.generate() as soon as stop_event is set.

    Without this, pressing Stop would only end the streaming loop below while
    the generation thread kept running on the GPU.
    """

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return stop_event.is_set()


def generate(
    message: str,
    show_thoughts: bool,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = False,
) -> Iterator[str]:
    """Stream a translation of `message`, optionally including the model's reasoning."""
    stop_event.clear()
    message = message.replace("\n", " ")  # keep the prompt on a single line

    # Prepare input for the model
    conversation = [
        {"role": "user", "content": f"Please translate the following into English: {message} Translation:"}
    ]
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Use a streamer to receive generated text chunk by chunk
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        stopping_criteria=StoppingCriteriaList([StopOnEvent()]),
    )
    # Sampling knobs are only meaningful when do_sample=True; passing them
    # otherwise triggers warnings from transformers.
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    # Generate in a separate thread so partial output can be streamed
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    in_translation = False
    for text in streamer:
        if stop_event.is_set():
            break
        # Watch for the marker that separates the reasoning from the translation.
        # This assumes the marker arrives within a single streamed chunk.
        if "#Final Translation:" in text and not in_translation:
            in_translation = True
            if not show_thoughts:
                # Drop the reasoning and keep only the text after the marker
                text = text.split("#Final Translation:", 1)[1].strip()
        if in_translation or show_thoughts:
            outputs.append(text)
            yield "".join(outputs)


def stop_generation():
    """Signal the streaming loop and the generation thread to stop."""
    stop_event.set()


# Create the Gradio interface
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_text = gr.Textbox(label="Enter Tibetan text", placeholder="Type Tibetan text here...")
        show_thoughts = gr.Checkbox(label="View Detailed Thoughts", value=True)
    submit_button = gr.Button("Translate")
    stop_button = gr.Button("Stop")
    with gr.Row():
        output_area = gr.Textbox(
            label="Output (Thoughts and Translation)",
            lines=20,
            interactive=False,
        )

    # Connect buttons to functions
    submit_button.click(
        fn=generate,
        inputs=[input_text, show_thoughts],
        outputs=output_area,
        queue=True,  # enable streaming of partial results
    )
    stop_button.click(stop_generation)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)
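

# ---------------------------------------------------------------------------
# Optional smoke test: a minimal sketch of driving generate() without the
# Gradio UI, e.g. from a Python shell once the model has loaded. The function
# name `_smoke_test` and the sample sentence are illustrative only; assuming
# this file is saved as app.py, it could be used as:
#   >>> from app import _smoke_test
#   >>> _smoke_test("བཀྲ་ཤིས་བདེ་ལེགས།")
def _smoke_test(text: str, show_thoughts: bool = True) -> None:
    """Print the final streamed output of generate() for a single input."""
    final = ""
    for partial in generate(text, show_thoughts):
        final = partial  # generate() yields the cumulative text so far
    print(final)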