sanaweb commited on
Commit
9d5d600
·
verified ·
1 Parent(s): bae88f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -1
app.py CHANGED
@@ -1,3 +1,228 @@
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- gr.load("models/PartAI/Dorna-Llama3-8B-Instruct").launch()
 
 
1
+ import os
2
+ from threading import Thread
3
+ from typing import Iterator
4
+
5
  import gradio as gr
6
+ from langfuse import Langfuse
7
+ from langfuse.decorators import observe
8
+ import spaces
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
+ import time
12
+
13
+ MAX_MAX_NEW_TOKENS = 1048
14
+ DEFAULT_MAX_NEW_TOKENS = 1024
15
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1096"))
16
+
17
+
18
+ DESCRIPTION = """\
19
+ # models/PartAI/Dorna-Llama3-8B-Instruct
20
+ """
21
+
22
+ PLACEHOLDER = """
23
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
24
+ <img src="https://avatars.githubusercontent.com/u/39557177?v=4" style="width: 80%; max-width: 550px; height: auto; opacity: 0.80; ">
25
+ <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Dorna-Llama3-8B-Instruct</h1>
26
+ </div>
27
+ """
28
+
29
+ custom_css = """
30
+ @import url('https://fonts.googleapis.com/css2?family=Vazirmatn&display=swap');
31
+ body, .gradio-container, .gr-button, .gr-input, .gr-slider, .gr-dropdown, .gr-markdown {
32
+ font-family: 'Vazirmatn', sans-serif !important;
33
+ }
34
+ ._button {
35
+ font-size: 20px;
36
+ }
37
+ pre, code {
38
+ direction: ltr !important;
39
+ unicode-bidi: plaintext !important;
40
+ }
41
+ """
42
+
43
+
44
+ system_prompt = str(os.getenv("SYSTEM_PROMPT"))
45
+
46
+ secret_key = str(os.getenv("LANGFUSE_SECRET_KEY"))
47
+ public_key = str(os.getenv("LANGFUSE_PUBLIC_KEY"))
48
+ host = str(os.getenv("LANGFUSE_HOST"))
49
+
50
+ langfuse = Langfuse(
51
+ secret_key=secret_key,
52
+ public_key=public_key,
53
+ host=host
54
+ )
55
+
56
+
57
+ def execution_time_calculator(start_time, log=True):
58
+ delta = time.time() - start_time
59
+ if log:
60
+ print("--- %s seconds ---" % (delta))
61
+ return delta
62
+
63
+ def token_per_second_calculator(tokens_count, time_delta):
64
+ return tokens_count/time_delta
65
+
66
+ if not torch.cuda.is_available():
67
+ DESCRIPTION = "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
68
+
69
+
70
+ if torch.cuda.is_available():
71
+ model_id = "PartAI/Dorna-Llama3-8B-Instruct"
72
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
73
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
74
+
75
+ generation_speed = 0
76
+
77
+ def get_generation_speed():
78
+ global generation_speed
79
+
80
+ return generation_speed
81
+
82
+ @observe()
83
+ def log_to_langfuse(message, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, do_sample, generation_speed, model_outputs):
84
+ print(f"generation_speed: {generation_speed}")
85
+ return "".join(model_outputs)
86
+
87
+
88
+ @spaces.GPU
89
+ def generate(
90
+ message: str,
91
+ chat_history: list[tuple[str, str]],
92
+ max_new_tokens: int = 1024,
93
+ temperature: float = 0.6,
94
+ top_p: float = 0.9,
95
+ top_k: int = 50,
96
+ repetition_penalty: float = 1.2,
97
+ do_sample: bool =True,
98
+ ) -> Iterator[str]:
99
+ global generation_speed
100
+ global system_prompt
101
+
102
+ conversation = []
103
+ if system_prompt:
104
+ conversation.append({"role": "system", "content": system_prompt})
105
+ for user, assistant in chat_history:
106
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
107
+ conversation.append({"role": "user", "content": message})
108
+
109
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
110
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
111
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
112
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
113
+ input_ids = input_ids.to(model.device)
114
+
115
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
116
+ generate_kwargs = dict(
117
+ {"input_ids": input_ids},
118
+ streamer=streamer,
119
+ max_new_tokens=max_new_tokens,
120
+ do_sample=do_sample,
121
+ top_p=top_p,
122
+ top_k=top_k,
123
+ temperature=temperature,
124
+ num_beams=1,
125
+ repetition_penalty=repetition_penalty,
126
+ )
127
+
128
+ start_time = time.time()
129
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
130
+ t.start()
131
+
132
+ outputs = []
133
+ sum_tokens = 0
134
+ for text in streamer:
135
+ num_tokens = len(tokenizer.tokenize(text))
136
+ sum_tokens += num_tokens
137
+
138
+ outputs.append(text)
139
+ yield "".join(outputs)
140
+
141
+ time_delta = execution_time_calculator(start_time, log=False)
142
+
143
+ generation_speed = token_per_second_calculator(sum_tokens, time_delta)
144
+
145
+ log_function = log_to_langfuse(
146
+ message=message,
147
+ chat_history=chat_history,
148
+ max_new_tokens=max_new_tokens,
149
+ temperature=temperature,
150
+ top_p=top_p,
151
+ top_k=top_k,
152
+ repetition_penalty=repetition_penalty,
153
+ do_sample=do_sample,
154
+ generation_speed=generation_speed,
155
+ model_outputs=outputs,
156
+ )
157
+
158
+
159
+
160
+
161
+
162
+ chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, show_copy_button=True, height="68%", rtl=True) #, elem_classes=["chatbot"])
163
+ chat_input = gr.Textbox(show_label=False, lines=2, rtl=True, placeholder="ورودی", show_copy_button=True, scale=4)
164
+ submit_btn = gr.Button(variant="primary", value="ارسال", size="sm", scale=1, elem_classes=["_button"])
165
+
166
+
167
+ chat_interface = gr.ChatInterface(
168
+ fn=generate,
169
+ additional_inputs_accordion=gr.Accordion(label="ورودی‌های اضافی", open=False),
170
+ additional_inputs=[
171
+ gr.Slider(
172
+ label="حداکثر تعداد توکن ها",
173
+ minimum=1,
174
+ maximum=MAX_MAX_NEW_TOKENS,
175
+ step=1,
176
+ value=DEFAULT_MAX_NEW_TOKENS,
177
+ ),
178
+ gr.Slider(
179
+ label="Temperature",
180
+ minimum=0.01,
181
+ maximum=4.0,
182
+ step=0.01,
183
+ value=0.5,
184
+ ),
185
+ gr.Slider(
186
+ label="Top-p",
187
+ minimum=0.05,
188
+ maximum=1.0,
189
+ step=0.01,
190
+ value=0.9,
191
+ ),
192
+ gr.Slider(
193
+ label="Top-k",
194
+ minimum=1,
195
+ maximum=1000,
196
+ step=1,
197
+ value=20,
198
+ ),
199
+ gr.Slider(
200
+ label="جریمه تکرار",
201
+ minimum=1.0,
202
+ maximum=2.0,
203
+ step=0.05,
204
+ value=1.2,
205
+ ),
206
+ gr.Dropdown(
207
+ label="نمونه‌گیری",
208
+ choices=[False, True],
209
+ value=True)
210
+ ],
211
+ stop_btn="توقف",
212
+ chatbot=chatbot,
213
+ textbox=chat_input,
214
+ submit_btn=submit_btn,
215
+ retry_btn="🔄 تلاش مجدد",
216
+ undo_btn="↩️ بازگشت",
217
+ clear_btn="🗑️ پاک کردن",
218
+ title="درنا، محصول مرکز تحقیقات هوش مصنوعی پارت"
219
+ )
220
+
221
+
222
+ with gr.Blocks(css=custom_css, fill_height=False) as demo:
223
+ gr.Markdown(DESCRIPTION)
224
+ chat_interface.render()
225
+
226
 
227
+ if __name__ == "__main__":
228
+ demo.queue(max_size=20).launch()