TenzinGayche committed on
Commit 4fe238f · verified · 1 Parent(s): 3ed952b

Update app.py

Files changed (1)
  1. app.py +63 -55
app.py CHANGED
@@ -3,15 +3,16 @@ from threading import Thread, Event
 from typing import Iterator
 
 import gradio as gr
-
 import torch
 from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
+
 DESCRIPTION = """\
-# Monlam LLM v2.0.1
+# Monlam LLM v2.0.1 - Thoughts and Translation
+This version generates detailed reasoning (thoughts) followed by a tokenized translation.
 """
-path="TenzinGayche/tpo_v1.0.0_ep2_dpo_ft"
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
+
+# Constants
+path = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft"
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 # Load the model and tokenizer
@@ -21,91 +22,98 @@ model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to
 model.config.sliding_window = 4096
 model.eval()
 
-# Create a shared stop event
+# Shared stop event
 stop_event = Event()
 
-def generate(
-    message: str,
-    chat_history: list[dict],
-    max_new_tokens: int = 2048,
+
+# Generate function
+def generate(message: str,
+    show_thoughts: bool,
+    max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-    do_sample: bool= False
+    do_sample: bool = False,
 ) -> Iterator[str]:
-    # Clear the stop event before starting a new generation
     stop_event.clear()
 
-
-    # Append the user's message to the conversation history
-    conversation = chat_history.copy()
-    if not conversation:
-        conversation.extend([
-            {
-                "role": "user",
-                "content": "ཁྱེད་རང་སྨོན་ལམ་མི་བཟོས་རིག་ནུས་ཤིག་ཡིན་པ་དང་ཁྱེད་རང་མི་བཟོས་རིག་ནུས་(AI)ཤིག་ཡིན།"
-            },
-            {
-                "role": "assistant",
-                "content": "ལགས་སོ། ང་ཡིས་ཁྱེད་ཀྱི་བཀའ་བཞིན་སྒྲུབ་ཆོག"
-            }
-        ])
-    conversation.append({"role": "user", "content": message})
-
+    # Prepare input for the model
+    conversation = [
+        {"role": "user", "content": f"Please translate the following into Germany: {message} Translation:"}
+    ]
     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
 
-    # Create a streamer to get the generated response
+    # Use a streamer to get generated text
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        {"input_ids": input_ids},
+        input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
-
     )
-
-    # Run generation in a background thread
+
+    # Generate in a separate thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     outputs = []
+    in_translation = False
+
     for text in streamer:
         if stop_event.is_set():
-            break  # Stop if the stop button is pressed
-        outputs.append(text)
-        yield "".join(outputs)
+            break
 
-    # After generation, append the assistant's response to the chat history
-    assistant_response = "".join(outputs)
-    chat_history.append({"role": "assistant", "content": assistant_response})
+        # Process the generated text
+        if "#Final Translation:" in text and not in_translation:
+            in_translation = True
+            if not show_thoughts:
+                text = text.split("#Final Translation:", 1)[1].strip()  # Skip reasoning if "View Thoughts" is disabled
 
+        if in_translation:
+            outputs.append(text)
+            yield "".join(outputs)
+        elif show_thoughts:
+            outputs.append(text)
+            yield "".join(outputs)
 
-    # Define a function to stop the generation
+    # Append assistant's response
+    chat_history = "".join(outputs)
+
+
+# Stop generation function
 def stop_generation():
     stop_event.set()
 
-# Create the chat interface with additional inputs and the stop button
+
+# Create the Gradio interface
 with gr.Blocks(css="style.css", fill_height=True) as demo:
     gr.Markdown(DESCRIPTION)
 
-    # Create the chat interface
-    chat_interface = gr.ChatInterface(
+    with gr.Row():
+        input_text = gr.Textbox(label="Enter Tibetan text", placeholder="Type Tibetan text here...")
+        show_thoughts = gr.Checkbox(label="View Detailed Thoughts", value=True)
+        submit_button = gr.Button("Translate")
+        stop_button = gr.Button("Stop")
+
+    with gr.Row():
+        output_area = gr.Textbox(
+            label="Output (Thoughts and Translation)",
+            lines=20,
+            interactive=False,
+        )
+
+    # Connect buttons to functions
+    submit_button.click(
         fn=generate,
-        examples=[
-            ["Hello there! How are you doing?"],
-            ["Can you explain briefly to me what is the Python programming language?"],
-            ["Explain the plot of Cinderella in a sentence."],
-            ["How many hours does it take a man to eat a Helicopter?"],
-            ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-        ],
-        cache_examples=False,
-        type="messages",
+        inputs=[input_text, show_thoughts],
+        outputs=output_area,
+        queue=True,  # Enable streaming
     )
-
+    stop_button.click(stop_generation)
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch(share=True)
+    demo.queue(max_size=20).launch(share=True)
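
For reference, the sketch below re-creates the thought/translation filtering that the new generate() loop applies to streamed chunks, so the effect of the "View Detailed Thoughts" checkbox can be seen without loading the model. It is not part of the commit: fake_chunks, filter_stream, and the sample strings are invented stand-ins for what TextIteratorStreamer would yield, and the sketch assumes the "#Final Translation:" marker arrives within a single chunk.

from threading import Event
from typing import Iterable, Iterator

stop_event = Event()

# Invented stand-in for the chunks a TextIteratorStreamer would yield.
fake_chunks = [
    "Thinking: the sentence is a polite greeting. ",
    "Choosing an equivalent register. ",
    "#Final Translation: Hello, how are you?",
]

def filter_stream(chunks: Iterable[str], show_thoughts: bool) -> Iterator[str]:
    # Mirrors the marker-splitting loop introduced in the new generate().
    outputs = []
    in_translation = False
    for text in chunks:
        if stop_event.is_set():
            break

        # Switch to "translation" mode at the marker; drop the reasoning
        # prefix when thoughts are hidden.
        if "#Final Translation:" in text and not in_translation:
            in_translation = True
            if not show_thoughts:
                text = text.split("#Final Translation:", 1)[1].strip()

        if in_translation:
            outputs.append(text)
            yield "".join(outputs)
        elif show_thoughts:
            outputs.append(text)
            yield "".join(outputs)

if __name__ == "__main__":
    # With thoughts hidden, only the translation is streamed.
    print(list(filter_stream(fake_chunks, show_thoughts=False))[-1])
    # With thoughts shown, the reasoning streams first, then the marker and translation.
    print(list(filter_stream(fake_chunks, show_thoughts=True))[-1])

With show_thoughts=False the last yielded string is just "Hello, how are you?"; with it enabled, the reasoning plus the "#Final Translation:" line is accumulated, which matches how the output_area textbox fills in while streaming.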