TenzinGayche committed (verified)
Commit 717452a · 1 Parent(s): 4fa525d

Update app.py

Files changed (1): app.py +127 -40
app.py CHANGED
@@ -1,50 +1,137 @@
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, GemmaTokenizerFast, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
- from threading import Thread
-
- # Load tokenizer and model
- tokenizer = GemmaTokenizerFast.from_pretrained("buddhist-nlp/gemma2-mitra-bo-instruct")
- model = AutoModelForCausalLM.from_pretrained("buddhist-nlp/gemma2-mitra-bo-instruct", torch_dtype=torch.float16).to('cuda:0')
-
- # Define custom stopping criteria
- class StopOnTokens(StoppingCriteria):
-     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-         # Define stop tokens (adjust based on your model's tokenizer)
-         stop_ids = [29, 0]  # These should be the token IDs for end of response or similar tokens
-         for stop_id in stop_ids:
-             if input_ids[0][-1] == stop_id:
-                 return True
-         return False
-
- # Define prediction function for the chat interface
- def predict(message, history):
-     # Format the input according to your specified structure
-     formatted_input = f"### user : {message} ### input: ### answer:"
-
-     # Tokenize the input
-     model_inputs = tokenizer([formatted_input], return_tensors="pt").to("cuda")
-
-     # Set up the streamer for partial message output
-     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
-
-     # Generate settings
      generate_kwargs = dict(
-         model_inputs,
          streamer=streamer,
-         max_new_tokens=1024
      )

-     # Run generation in a separate thread
      t = Thread(target=model.generate, kwargs=generate_kwargs)
      t.start()

-     # Stream partial messages as they are generated
-     partial_message = ""
-     for new_token in streamer:
-         if new_token != '<':  # Skip specific tokens if necessary
-             partial_message += new_token
-             yield partial_message

- # Create the chat interface using Gradio
- gr.ChatInterface(fn=predict, title="Monlam LLM", description="").launch(share=True)

+ import os
+ from threading import Thread, Event
+ from typing import Iterator
+
  import gradio as gr
+
  import torch
+ from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
+
+ DESCRIPTION = """\
+ # Gemma 2 2B IT
+ Gemma 2 is Google's latest iteration of open LLMs.
+ This is a demo of [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it), fine-tuned for instruction following.
+ For more details, please check [our post](https://huggingface.co/blog/gemma2).
+ 👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it) and the 9B version in [this Space](https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it).
+ """
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ # Load the model and tokenizer
+ tokenizer = GemmaTokenizerFast.from_pretrained("TenzinGayche/example")
+ model = AutoModelForCausalLM.from_pretrained("TenzinGayche/example", torch_dtype=torch.float16).to("cuda")
+
+ model.config.sliding_window = 4096  # cap Gemma 2's sliding-window attention at 4096 tokens
+ model.eval()
+
+ # Create a shared stop event (module-level, so one Stop click cancels any in-flight stream)
+ stop_event = Event()
+
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
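+     # gr.ChatInterface treats a generator handler as a streaming response:
+     # each yielded string replaces the assistant message shown so far.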
+     # Clear the stop event before starting a new generation
+     stop_event.clear()
+
+     conversation = chat_history.copy()
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
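+     # apply_chat_template renders the conversation with the tokenizer's own
+     # chat template; add_generation_prompt=True appends the tokens that open
+     # the assistant turn, so decoding continues as the model's reply.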
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = dict(
+         {"input_ids": input_ids},
          streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         # Wire the slider-controlled sampling settings into generation
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
      )

      t = Thread(target=model.generate, kwargs=generate_kwargs)
      t.start()

+     outputs = []
+     for text in streamer:
+         if stop_event.is_set():
+             break  # Stop if the stop button is pressed
+         outputs.append(text)
+         yield "".join(outputs)
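+     # Caveat: breaking out of this loop only ends the UI stream; the
+     # model.generate thread keeps decoding until it finishes or reaches
+     # max_new_tokens (see the stopping-criteria sketch after the diff).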
+
+ # Define a function to stop the generation
+ def stop_generation():
+     stop_event.set()
+
+ # Create the chat interface with additional inputs and the stop button
+ with gr.Blocks(css="style.css", fill_height=True) as demo:
+     gr.Markdown(DESCRIPTION)
+
+     # Create the chat interface
+     chat_interface = gr.ChatInterface(
+         fn=generate,
+         additional_inputs=[
+             gr.Slider(
+                 label="Max new tokens",
+                 minimum=1,
+                 maximum=MAX_MAX_NEW_TOKENS,
+                 step=1,
+                 value=DEFAULT_MAX_NEW_TOKENS,
+             ),
+             gr.Slider(
+                 label="Temperature",
+                 minimum=0.1,
+                 maximum=4.0,
+                 step=0.1,
+                 value=0.6,
+             ),
+             gr.Slider(
+                 label="Top-p (nucleus sampling)",
+                 minimum=0.05,
+                 maximum=1.0,
+                 step=0.05,
+                 value=0.9,
+             ),
+             gr.Slider(
+                 label="Top-k",
+                 minimum=1,
+                 maximum=1000,
+                 step=1,
+                 value=50,
+             ),
+             gr.Slider(
+                 label="Repetition penalty",
+                 minimum=1.0,
+                 maximum=2.0,
+                 step=0.05,
+                 value=1.2,
+             ),
+         ],
+         examples=[
+             ["Hello there! How are you doing?"],
+             ["Can you explain briefly to me what is the Python programming language?"],
+             ["Explain the plot of Cinderella in a sentence."],
+             ["How many hours does it take a man to eat a Helicopter?"],
+             ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+         ],
+         cache_examples=False,
+         type="messages",
+     )
+
+     # Create the stop button inside the Blocks context
+     stop_button = gr.Button("Stop", elem_id="stop-btn")
+     stop_button.click(fn=stop_generation, inputs=[], outputs=[])
+
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()

+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch(share=True)
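A note on the Stop button as committed: setting stop_event only breaks the streaming loop in generate, so the UI stops updating while the background model.generate thread keeps decoding until it finishes or hits max_new_tokens. A minimal sketch of how the same shared event could cancel decoding itself, via transformers' StoppingCriteria hook; the StopOnEvent name is our illustration, not part of the commit:

from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnEvent(StoppingCriteria):
    # Checked after every decoding step; returning True halts model.generate.
    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return stop_event.is_set()

# Hypothetical wiring inside generate(), next to the existing kwargs:
# generate_kwargs["stopping_criteria"] = StoppingCriteriaList([StopOnEvent()])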