MohamedRashad committed
Commit 675a4cb · verified · 1 Parent(s): e769375

Update app.py

Files changed (1)
app.py  +15 -12
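For context, a minimal, self-contained sketch of the streaming pattern the updated respond() relies on: the chat is tokenized with apply_chat_template, model.generate runs on a background thread feeding a TextIteratorStreamer, and the growing response is yielded chunk by chunk (as gr.ChatInterface expects). The model id and HF_TOKEN handling follow the diff below; the model-specific enable_reasoning template kwarg and the flash_attention_2 setting are left out of this sketch.

import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Mulhem-1-Mini", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    "Navid-AI/Mulhem-1-Mini", torch_dtype=torch.bfloat16, token=os.getenv("HF_TOKEN")
).to(device)


def stream_reply(messages, max_new_tokens=256):
    # One streamer per call so concurrent requests don't interleave chunks.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # return_dict=True yields input_ids plus attention_mask, ready for generate().
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    ).to(device)
    # generate() blocks, so it runs on a background thread while we consume the streamer.
    Thread(
        target=model.generate,
        kwargs=dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens),
    ).start()
    response = ""
    for chunk in streamer:
        response += chunk
        yield response  # yield the running text, not just the newest chunk


if __name__ == "__main__":
    for partial in stream_reply([{"role": "user", "content": "مرحبا"}]):
        print(partial, flush=True)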
app.py CHANGED
@@ -1,17 +1,18 @@
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
+import os
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
-import os
 from threading import Thread
 
+
 # Load model directly
 device = "cuda" if torch.cuda.is_available() else "cpu"
 tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Mulhem-1-Mini", token=os.getenv("HF_TOKEN"))
-model = AutoModelForCausalLM.from_pretrained("Navid-AI/Mulhem-1-Mini", torch_dtype=torch.bfloat16, token=os.getenv("HF_TOKEN")).to(device)
-streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+model = AutoModelForCausalLM.from_pretrained("Navid-AI/Mulhem-1-Mini", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", token=os.getenv("HF_TOKEN")).to(device)
+streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 def respond(
     message,
@@ -20,36 +21,38 @@ def respond(
     system_message,
     max_tokens,
     temperature,
+    repetition_penalty,
     top_p,
 ):
     messages = [{"role": "system", "content": system_message}]
 
     for val in history:
         if val[0]:
-            messages.append({"role": "user", "content": val[0]})
+            messages.append({"role": "user", "content": val[0].strip()})
         if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
+            messages.append({"role": "assistant", "content": val[1].strip()})
 
     messages.append({"role": "user", "content": message})
-    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, enable_reasoning=enable_reasoning).to(device)
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p)
+    print(messages)
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, enable_reasoning=enable_reasoning, return_dict=True).to(device)
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
     thread.start()
+    response = ""
     for new_text in streamer:
-        yield new_text
+        response += new_text
+        yield response
 
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
         gr.Checkbox(label="Enable reasoning", value=False),
         gr.Textbox(value="أنت مُلهم. ذكاء اصطناعي تم إنشاؤه من شركة نفيد لإلهام وتحفيز المستخدمين على التعلّم، النمو، وتحقيق أهدافهم.", label="System message"),
         gr.Slider(minimum=1, maximum=8192, value=2048, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
+        gr.Slider(minimum=0.1, maximum=2.0, value=1.25, step=0.05, label="Repetition penalty"),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,