TenzinGayche commited on
Commit
fcc0568
·
verified ·
1 Parent(s): 3a9dc6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -78
app.py CHANGED
@@ -1,122 +1,162 @@
1
  import os
2
  from threading import Thread, Event
3
  from typing import Iterator
4
-
5
  import gradio as gr
6
  import torch
7
  from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
 
8
 
9
- DESCRIPTION = """\
10
- # Monlam LLM v2.0.1 -Translation
11
-
12
- ## This version first generates detailed reasoning (thoughts) and then, after the marker #Final Translation, the translation is produced.
13
-
14
- """
15
-
16
- # Constants
17
  path = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft"
18
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
19
 
20
  # Load the model and tokenizer
21
  tokenizer = GemmaTokenizerFast.from_pretrained(path)
22
  model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
23
-
24
  model.config.sliding_window = 4096
25
  model.eval()
26
  model.config.use_cache = True
27
- # Shared stop event
28
- stop_event = Event()
29
 
 
30
 
31
- # Generate function
32
- def generate(message: str,
33
- show_thoughts: bool,
34
- max_new_tokens: int = 1024,
35
- temperature: float = 0.6,
36
- top_p: float = 0.9,
37
- top_k: int = 50,
38
- repetition_penalty: float = 1.2,
39
- do_sample: bool = False,
40
- ) -> Iterator[str]:
41
  stop_event.clear()
42
- message=message.replace('\n',' ')
43
-
 
 
 
 
 
 
 
 
44
  # Prepare input for the model
45
  conversation = [
46
  {"role": "user", "content": f"Please translate the following into English: {message} Translation:"}
47
  ]
48
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
 
49
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
50
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
51
  gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
52
  input_ids = input_ids.to(model.device)
53
-
54
- # Use a streamer to get generated text
55
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
 
56
  generate_kwargs = dict(
57
  input_ids=input_ids,
58
  streamer=streamer,
59
- max_new_tokens=max_new_tokens,
60
  )
61
-
62
- # Generate in a separate thread
63
- t = Thread(target=model.generate, kwargs=generate_kwargs)
64
- t.start()
65
-
66
- outputs = []
 
67
  in_translation = False
68
-
 
 
 
 
 
 
 
 
 
69
  for text in streamer:
70
- if stop_event.is_set():
71
- break
72
-
73
- # Process the generated text
74
- if "#Final Translation:" in text and not in_translation:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  in_translation = True
76
- if not show_thoughts:
77
- text = text.split("#Final Translation:", 1)[1].strip() # Skip reasoning if "View Thoughts" is disabled
78
-
 
 
 
 
 
 
79
  if in_translation:
80
- outputs.append(text)
81
- yield "".join(outputs)
82
- elif show_thoughts:
83
- outputs.append(text)
84
- yield "".join(outputs)
85
-
86
- # Append assistant's response
87
- chat_history = "".join(outputs)
88
-
89
-
90
- # Stop generation function
91
- def stop_generation():
92
- stop_event.set()
93
-
94
-
95
- # Create the Gradio interface
96
- with gr.Blocks(css="style.css", fill_height=True) as demo:
97
- gr.Markdown(DESCRIPTION)
98
-
99
- with gr.Row():
100
- input_text = gr.Textbox(label="Enter Tibetan text", placeholder="Type Tibetan text here...")
101
- show_thoughts = gr.Checkbox(label="View Detailed Thoughts", value=True)
102
- submit_button = gr.Button("Translate")
103
- stop_button = gr.Button("Stop")
104
-
 
 
 
105
  with gr.Row():
106
- output_area = gr.Textbox(
107
- label="Output (Thoughts and Translation)",
108
- lines=20,
109
- interactive=False,
 
110
  )
 
 
 
 
 
 
 
 
 
111
 
112
- # Connect buttons to functions
113
- submit_button.click(
114
- fn=generate,
115
- inputs=[input_text, show_thoughts],
116
- outputs=output_area,
117
- queue=True, # Enable streaming
118
  )
119
- stop_button.click(stop_generation)
120
 
 
 
 
 
 
 
121
  if __name__ == "__main__":
122
  demo.queue(max_size=20).launch(share=True)
 
1
  import os
2
  from threading import Thread, Event
3
  from typing import Iterator
 
4
  import gradio as gr
5
  import torch
6
  from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
7
+ from gradio import ChatMessage
8
 
9
+ # Constants and model initialization
 
 
 
 
 
 
 
10
  path = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft"
11
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
12
 
13
  # Load the model and tokenizer
14
  tokenizer = GemmaTokenizerFast.from_pretrained(path)
15
  model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
 
16
  model.config.sliding_window = 4096
17
  model.eval()
18
  model.config.use_cache = True
 
 
19
 
20
+ stop_event = Event()
21
 
22
+ def stream_translation(user_message: str, messages: list) -> Iterator[list]:
 
 
 
 
 
 
 
 
 
23
  stop_event.clear()
24
+ message = user_message.replace('\n', ' ')
25
+
26
+ # Initialize the chat history if empty
27
+ if not messages:
28
+ messages = []
29
+
30
+ # Add user message if not already present
31
+ if not messages or (isinstance(messages[-1], dict) and messages[-1]["role"] != "user"):
32
+ messages.append({"role": "user", "content": message})
33
+
34
  # Prepare input for the model
35
  conversation = [
36
  {"role": "user", "content": f"Please translate the following into English: {message} Translation:"}
37
  ]
38
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
39
+
40
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
41
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
42
  gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
43
+
44
  input_ids = input_ids.to(model.device)
 
 
45
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
46
+
47
+ # Generation parameters
48
  generate_kwargs = dict(
49
  input_ids=input_ids,
50
  streamer=streamer,
51
+ max_new_tokens=2000,
52
  )
53
+
54
+ # Start generation in a separate thread
55
+ Thread(target=model.generate, kwargs=generate_kwargs).start()
56
+
57
+ # Initialize response tracking
58
+ thought_buffer = ""
59
+ translation_buffer = ""
60
  in_translation = False
61
+ accumulated_text = ""
62
+
63
+ # Add initial thinking message
64
+ messages.append({
65
+ "role": "assistant",
66
+ "content": "",
67
+ "metadata": {"title": "🤔 Thinking...", "status": "pending"}
68
+ })
69
+ yield messages
70
+
71
  for text in streamer:
72
+ accumulated_text += text
73
+
74
+ # Check for the marker in the accumulated text
75
+ if "#Final Translation:" in accumulated_text and not in_translation:
76
+ # Split at the marker and handle both parts
77
+ parts = accumulated_text.split("#Final Translation:", 1)
78
+ thought_buffer = parts[0].strip()
79
+ translation_start = parts[1] if len(parts) > 1 else ""
80
+
81
+ # Complete the thinking phase
82
+ messages[-1] = {
83
+ "role": "assistant",
84
+ "content": thought_buffer,
85
+ "metadata": {"title": "🤔 Thought Process", "status": "done"},
86
+ "collapsed": True
87
+ }
88
+ yield messages
89
+ thought_buffer=""
90
+
91
+ # Start translation phase as a normal message
92
  in_translation = True
93
+ messages.append({
94
+ "role": "assistant",
95
+ "content": translation_start.strip() # No metadata for normal response
96
+ })
97
+ translation_buffer = translation_start
98
+ yield messages
99
+
100
+ continue
101
+
102
  if in_translation:
103
+ translation_buffer += text
104
+ messages[-1] = {
105
+ "role": "assistant",
106
+ "content": translation_buffer.strip() # No metadata for normal response
107
+ }
108
+ else:
109
+ thought_buffer += text
110
+ messages[-1] = {
111
+ "role": "assistant",
112
+ "content": thought_buffer.strip(),
113
+ "metadata": {"title": "🤔 Thinking...", "status": "pending"}
114
+ }
115
+
116
+ yield messages
117
+
118
+ with gr.Blocks(title="Monlam Translation System", css="""
119
+ /* ... (keep existing CSS) */
120
+ """, theme=gr.themes.Soft()) as demo:
121
+ gr.Markdown("# 💭 Samloe Melong Translate")
122
+ gr.Markdown("It's a proof of concept. The model first generates detailed reasoning and then provides the translation. It only works for Tibetan to English (for now)!!")
123
+
124
+ chatbot = gr.Chatbot(
125
+ type="messages",
126
+ show_label=False,
127
+ render_markdown=True,
128
+ height=400
129
+ )
130
+
131
  with gr.Row():
132
+ input_box = gr.Textbox(
133
+ lines=3,
134
+ label="Enter Tibetan text",
135
+ placeholder="Type Tibetan text here...",
136
+ show_label=True,
137
  )
138
+ submit_btn = gr.Button("Translate", variant="primary", scale=0.15)
139
+
140
+ # Add example section AFTER defining input_box
141
+ examples = [
142
+ ["རྟག་པར་མངོན་ཞེན་གྱིས་བསླང་བའམ། །གཉེན་པོ་ཡིས་ནི་བསླང་བ་ཉིད། །ཡོན་ཏན་དང་ནི་ཕན་འདོགས་ཞིང་། །སྡུག་བསྔལ་བ་ལ་དགེ་ཆེན་འགྱུར། "],
143
+ ["ད་ཆ་ཨ་རིའི་ཚོང་རའི་ནང་དུ་གླེང་གཞི་ཤུགས་ཆེར་འགྱུར་བཞིན་པའི་ Deep Seek ཞེས་རྒྱ་ནག་གི་མི་བཟོས་རིག་ནུས་མཉེན་ཆས་དེས་བོད་ནང་དུ་དེ་སྔ་ནས་དམ་དྲག་ཤུགས་ཆེ་ཡོད་པའི་ཐོག་ད་དུང་ཤུགས་ཆེ་རུ་གཏོང་སྲིད་པ་གསུངས་སོང་།"],
144
+ ["མཉེན་ཆས་འདི་བཞིན་ཨ་རི་དང་རྒྱ་ནག་གཉིས་དབར་ཚོང་འབྲེལ་བཀག་སྡོམ་གྱི་གནད་དོན་ཁྲོད་ཨ་རིའི་མི་བཟོས་རིག་ནུས་ཀྱི་ Chips ཅིབ་སེ་མ་ལག་རྒྱ་ནག་ནང་དུ་ཚོང་འགྲེམ་བཀག་སྡོམ་བྱས་མིན་ལ་མ་ལྟོས་པར། ཚོང་འབྲེལ་བཀག་སྡོམ་གྱི་སྔོན་ཚུད་ནས་རྒྱ་ནག་གི་ཉོ་ཚོང་བྱས་པའི་ཅིབ་སེ་མ་ལག་དོན་ཐེངས་རྙིང་པའི་ཐོག་བཟོ་བསྐྲུན་བྱས་པ་དང་། ཨ་སྒོར་ཐེར་འབུམ་མང་པོའི་འགྲོ་གྲོན་ཐོག་བཟོ་བསྐྲུན་བྱས་པའི་ AI འམ་མི་བཟོས་རིག་ནུས་ཀྱི་མཉེན་ཆས་གཞན་དང་མི་འདྲ་བར་ Deep seek མཉེན་ཆས་དེ་བཞིན་ཨ་སྒོར་ས་ཡ་ ༦ ཁོ་ནའི་འགྲོ་གྲོན་ཐོག་བཟོ་བསྐྲུན་བྱས་པའི་གནད་དོན་སོགས་ཀྱི་རྐྱེན་པས་ཁ་སང་ཨ་རིའི་ཚོང་རའི་ནང་དུ་མི་བཟོས་རིག་ནུས་མཉེན་ཆས་འཕྲུལ་རིག་གི་ Chips ཅིབ་སེ་མ་ལག་བཟོ་བསྐྲུན་བྱས་མཁན་ NVidia ལྟ་བུར་ཨ་སྒོར་ཐེར་འབུམ་ ༦ མིན་ཙམ་གྱི་གྱོན་རྒུད་ཕོག་པའི་གནས་ཚུལ་བྱུང་ཡོད་འདུག"],
145
+ ["དེ་ཡང་དེ་རིང་ BBC དང་ Reuters སོགས་རྒྱལ་སྤྱིའི་བརྒྱུད་ལམ་ཁག་གི་གནས་ཚུ���་སྤེལ་བར་གཞིགས་ན། རྒྱ་ནག་གི་ Huangzhou གྲོང་ཁྱེར་ནང་དུ་བཟོ་བསྐྲུན་བྱས་པའི་ Deep Seek མི་བཟོས་རིག་ནུས་མཉེན་ཆས་དེ་བཞིན་ ChatGPT དང་ Gemini སོགས་མི་བཟོས་རིག་ནུས་ཀྱི་མཉེན་ཆས་གཞན་དང་བསྡུར་ན་མགྱོགས་ཚད་དང་ནུས་པའི་ཆ་ནས་གཅིག་མཚུངས་ཡོད་པ་མ་ཟད། མཉེན་ཆས་དེ་ཉིད་རིན་མེད་ཡིན་པའི་ཆ་ནས་ཨ་རི་དང་ཨིན་ཡུལ། དེ་བཞིན་རྒྱ་ནག་གསུམ་གྱི་ནང་དུ་སྐུ་ཤུ་རྟགས་ཅན་འཕྲུལ་རིག་གི་ App Store གཉེན་ཆས་བངས་མཛོད་ནང་ནས་ Deep Seek དེ་ཉིད་མཉེན་ཆས་ཕབ་ལེན་མང་ཤོས་བྱས་པ་ཞིག་ཆགས་ཡོད་པ་རེད་འདུག"],
146
+ ]
147
 
148
+ gr.Examples(
149
+ examples=examples,
150
+ inputs=[input_box],
151
+ label="Try these examples",
152
+ examples_per_page=3
 
153
  )
 
154
 
155
+ # Connect components with correct inputs
156
+ submit_btn.click(
157
+ fn=stream_translation,
158
+ inputs=[input_box, chatbot],
159
+ outputs=chatbot
160
+ )
161
  if __name__ == "__main__":
162
  demo.queue(max_size=20).launch(share=True)