hfl-rc committed on
Commit
e0dec7a
1 Parent(s): c796c3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from threading import Thread
2
-
3
  import gradio as gr
4
  import spaces
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
@@ -17,7 +17,7 @@ BANNER_HTML = """
17
  </center>
18
  </h3>
19
  <p>
20
- <center><em>The demo is mainly for academic purposes. Do not use this demo for illegal activities. Default model: <a href="https://huggingface.co/hfl/llama-3-chinese-8b-instruct-v3">hfl/llama-3-chinese-8b-instruct-v3</a></em></center>
21
  </p>
22
  """
23
 
@@ -34,7 +34,7 @@ def load_model(version):
34
  model_name = "hfl/llama-3-chinese-8b-instruct-v3"
35
 
36
  tokenizer = AutoTokenizer.from_pretrained(model_name)
37
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
38
  return f"Model {model_name} loaded."
39
 
40
  @spaces.GPU(duration=50)
@@ -77,7 +77,7 @@ with gr.Blocks() as demo:
77
  gr.Text(value=DEFAULT_SYSTEM_PROMPT, label="System Prompt / 系统提示词", render=False),
78
  gr.Radio(choices=["v1", "v2", "v3"], label="Model Version / 模型版本", value="v3", interactive=False, render=False),
79
  gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Temperature / 温度系数", render=False),
80
- gr.Slider(minimum=128, maximum=2048, step=1, value=256, label="Max new tokens / 最大生成长度", render=False),
81
  ],
82
  cache_examples=False,
83
  )
 
1
  from threading import Thread
2
+ import torch
3
  import gradio as gr
4
  import spaces
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
17
  </center>
18
  </h3>
19
  <p>
20
+ <center><em>The demo is mainly for academic purposes and users are not expected to use this demo for illegal activities.</em></center>
21
  </p>
22
  """
23
 
 
34
  model_name = "hfl/llama-3-chinese-8b-instruct-v3"
35
 
36
  tokenizer = AutoTokenizer.from_pretrained(model_name)
37
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2")
38
  return f"Model {model_name} loaded."
39
 
40
  @spaces.GPU(duration=50)
 
77
  gr.Text(value=DEFAULT_SYSTEM_PROMPT, label="System Prompt / 系统提示词", render=False),
78
  gr.Radio(choices=["v1", "v2", "v3"], label="Model Version / 模型版本", value="v3", interactive=False, render=False),
79
  gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Temperature / 温度系数", render=False),
80
+ gr.Slider(minimum=512, maximum=2048, step=1, value=256, label="Max new tokens / 最大生成长度", render=False),
81
  ],
82
  cache_examples=False,
83
  )