File size: 3,700 Bytes
717452a
 
 
 
e9bec21
 
717452a
4fe238f
717452a
1a28685
3a9dc6c
 
 
717452a
4fe238f
 
 
717452a
 
 
d117130
 
717452a
 
 
7a8369a
4fe238f
717452a
 
4fe238f
 
 
 
 
717452a
 
 
 
4fe238f
717452a
 
0ed057c
717452a
4fe238f
 
0ed057c
4fe238f
717452a
 
 
4fe238f
717452a
 
4fe238f
717452a
e9bec21
4fe238f
e9bec21
717452a
e9bec21
4fe238f
 
e9bec21
 
 
717452a
4fe238f
 
717452a
 
4fe238f
717452a
4fe238f
 
 
 
 
a079f79
4fe238f
 
 
 
 
 
a079f79
4fe238f
 
 
 
 
717452a
 
 
4fe238f
 
717452a
 
 
4fe238f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
717452a
4fe238f
 
 
717452a
4fe238f
e9bec21
717452a
4fe238f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
from threading import Thread, Event
from typing import Iterator

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer

# Markdown header rendered at the top of the Gradio page.
DESCRIPTION = """\
# Monlam LLM v2.0.1 -Translation

## This version first generates detailed reasoning (thoughts) and then, after the marker #Final Translation, the translation is produced.

"""

# Constants
# Hugging Face Hub repo id (or local directory) of the fine-tuned checkpoint.
path = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft"
# Prompt-length cap in tokens; can be overridden with the MAX_INPUT_TOKEN_LENGTH env var.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Load the model and tokenizer
# Both are loaded once at import time. The model is cast to fp16 and moved to
# "cuda", so importing this module requires a CUDA-capable device.
tokenizer = GemmaTokenizerFast.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda")

# NOTE(review): overrides the checkpoint's sliding-window attention size;
# presumably chosen to match MAX_INPUT_TOKEN_LENGTH — confirm.
model.config.sliding_window = 4096
model.eval()  # inference mode (disables dropout etc.)
model.config.use_cache = True  # reuse cached key/values between decoding steps
# Shared stop event
# Module-global flag: stop_generation() sets it and generate() clears/polls it.
# Because it is global, every request served by this process shares it.
stop_event = Event()


# Generate function
def generate(message: str,
    show_thoughts: bool,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = False,
) -> Iterator[str]:
    """Stream an English translation of the Tibetan text in ``message``.

    The model first writes free-form reasoning, then the marker
    ``#Final Translation:`` followed by the translation itself.

    Args:
        message: Source text. Newlines are collapsed to spaces before prompting.
        show_thoughts: If True, stream the reasoning and the translation.
            If False, stream only the text that follows ``#Final Translation:``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling controls.
            They are forwarded to ``model.generate`` only when ``do_sample`` is
            True, so the default greedy path keeps the checkpoint's own
            generation config.
        do_sample: Use sampling instead of greedy decoding.

    Yields:
        The complete text to display so far; each yield replaces the previous one.

    Note:
        The Stop button sets the module-level ``stop_event``. That flag is shared
        by all requests, and this function clears it at the start of every call.
    """
    # Local import so the module-level import block stays unchanged.
    from transformers import StoppingCriteria, StoppingCriteriaList

    marker = "#Final Translation:"

    class _StopOnEvent(StoppingCriteria):
        """Ends ``model.generate`` as soon as ``stop_event`` is set."""

        def __call__(self, input_ids, scores, **kwargs):
            # One done-flag per sequence in the batch.
            return torch.full(
                (input_ids.shape[0],),
                stop_event.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    stop_event.clear()
    message = message.replace("\n", " ")
    if not message.strip():
        # Nothing to translate: clear the output instead of prompting the model.
        yield ""
        return

    # Prepare input for the model
    conversation = [
        {"role": "user", "content": f"Please translate the following into English: {message} Translation:"}
    ]
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the trailing tokens, since add_generation_prompt puts the assistant cue at the end.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # The streamer hands decoded text from the generation thread to this generator.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        # Without this, breaking out of the loop below would leave the worker
        # thread generating (and holding the GPU) until max_new_tokens.
        stopping_criteria=StoppingCriteriaList([_StopOnEvent()]),
    )
    if do_sample:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )

    # Generate in a separate thread so this generator can consume the stream.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    parts = []
    translation_start = None  # offset just past the marker, once it has been seen
    stopped = False
    for chunk in streamer:
        if stop_event.is_set():
            stopped = True
            break
        parts.append(chunk)
        generated = "".join(parts)

        if show_thoughts:
            yield generated
            continue

        # Search the accumulated text rather than the single chunk: the streamer
        # emits roughly word-sized pieces, so "#Final" and "Translation:" usually
        # arrive in different chunks.
        if translation_start is None:
            idx = generated.find(marker)
            if idx != -1:
                translation_start = idx + len(marker)
        if translation_start is not None:
            # Strip only the leading side. Stripping both ends would drop the
            # spaces that separate words arriving in later chunks.
            yield generated[translation_start:].lstrip()

    if not show_thoughts and translation_start is None and not stopped:
        # The model never emitted the marker. Show what it produced rather than
        # leaving the output box empty.
        yield "".join(parts).strip()


def stop_generation():
    """Ask the in-flight translation stream to stop.

    Raises the module-level ``stop_event`` flag, which ``generate`` checks
    between streamed chunks and clears again at the start of each new call.
    """
    stop_event.set()


# Create the Gradio interface
# Components render in the order they are created inside each `with` context,
# so the statements below are order-sensitive.
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)

    # Input row: source text, thoughts toggle, and the two action buttons.
    with gr.Row():
        input_text = gr.Textbox(label="Enter Tibetan text", placeholder="Type Tibetan text here...")
        show_thoughts = gr.Checkbox(label="View Detailed Thoughts", value=True)
        submit_button = gr.Button("Translate")
        stop_button = gr.Button("Stop")

    # Output row: read-only box that receives the text streamed by `generate`.
    with gr.Row():
        output_area = gr.Textbox(
            label="Output (Thoughts and Translation)",
            lines=20,
            interactive=False,
        )

    # Connect buttons to functions
    # Only `message` and `show_thoughts` are wired from the UI; every other
    # `generate` parameter uses its default value.
    submit_button.click(
        fn=generate,
        inputs=[input_text, show_thoughts],
        outputs=output_area,
        queue=True,  # Enable streaming
    )
    # Only sets the shared stop flag; this event has no inputs or outputs.
    stop_button.click(stop_generation)

if __name__ == "__main__":
    # Queue holds at most 20 pending requests; share=True asks Gradio for a public link.
    demo.queue(max_size=20).launch(share=True)