import os
from threading import Thread, Event
from typing import Iterator
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
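
# Dependencies: gradio, torch, and transformers; fp16 inference below assumes a CUDA GPU.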

DESCRIPTION = """\
# Monlam LLM v2.0.1 - Translation
## This version first generates detailed reasoning (thoughts); the final translation follows the marker `#Final Translation:`.
"""

# Constants
MODEL_ID = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft"
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Load the model and tokenizer
tokenizer = GemmaTokenizerFast.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to("cuda")
model.config.sliding_window = 4096
model.eval()
model.config.use_cache = True

# Shared stop event
stop_event = Event()
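# NOTE: the Event is shared across all requests, so pressing Stop cancels every
# in-flight generation; acceptable for a single-user demo Space.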


# Generate function
def generate(
    message: str,
    show_thoughts: bool,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = False,
) -> Iterator[str]:
    stop_event.clear()
    # Collapse newlines so the prompt stays on a single line
    message = message.replace("\n", " ")

    # Prepare input for the model using the chat template
    conversation = [
        {"role": "user", "content": f"Please translate the following into English: {message} Translation:"}
    ]
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the context window
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Use a streamer to get generated text as it is produced; forward the sampling
    # parameters, which were previously accepted but never passed to the model
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
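    # Note: with do_sample=False generation is greedy, and transformers will
    # warn that the sampling parameters above are effectively ignored.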

    # Generate in a separate thread
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate the stream in a buffer: the "#Final Translation:" marker may be
    # split across streamer chunks, so search the text seen so far rather than
    # each chunk in isolation.
    buffer = ""
    for text in streamer:
        if stop_event.is_set():
            break
        buffer += text
        if show_thoughts:
            yield buffer
        elif "#Final Translation:" in buffer:
            # Skip the reasoning when "View Detailed Thoughts" is disabled
            yield buffer.split("#Final Translation:", 1)[1].strip()


# Stop generation function
def stop_generation():
    stop_event.set()


# Create the Gradio interface
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_text = gr.Textbox(label="Enter Tibetan text", placeholder="Type Tibetan text here...")
        show_thoughts = gr.Checkbox(label="View Detailed Thoughts", value=True)
        submit_button = gr.Button("Translate")
        stop_button = gr.Button("Stop")
    with gr.Row():
        output_area = gr.Textbox(
            label="Output (Thoughts and Translation)",
            lines=20,
            interactive=False,
        )

    # Connect buttons to functions
    submit_button.click(
        fn=generate,
        inputs=[input_text, show_thoughts],
        outputs=output_area,
        queue=True,  # Enable streaming of partial outputs
    )
    stop_button.click(stop_generation)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)
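
# For reference, a minimal sketch of driving the app programmatically with
# gradio_client (assuming the app is reachable at this URL and that the click
# handler is exposed under the auto-generated endpoint name "/generate"):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict("བཀྲ་ཤིས་བདེ་ལེགས།", True, api_name="/generate")
#   print(result)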