CMLL committed
Commit fdf8c66
1 parent: f3b7005

Update app.py

Files changed (1): app.py (+49, -51)
app.py CHANGED
@@ -12,61 +12,58 @@ DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 DESCRIPTION = """\
-
-ZhongJing-2-1_8b Chat
-This Space demonstrates the ZhongJing-2-1_8b model, a fine-tuned model for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also deploy the model on Inference Endpoints.
+仲景GPT-V2-1.8B
+博极医源,精勤不倦。Unlocking the Wisdom of Traditional Chinese Medicine with AI.
 """
 
 LICENSE = """
-
 <p/>
 ---
-As a derivate work of [ZhongJing-2-1_8b](https://huggingface.co/CMLM/ZhongJing-2-1_8b) by 医哲未来 of Fudan University, this demo is governed by the original license and acceptable use policy.
+This demo is governed by the original licenses of [ZhongJing-2-1_8b](https://huggingface.co/CMLM/ZhongJing-2-1_8b) and [Qwen1.5-1.8B-Chat](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat).
 """
 
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-if torch.cuda.is_available():
-    base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
-    peft_model_id = "CMLM/ZhongJing-2-1_8b"
-    model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16, device_map="auto")
-    model.load_adapter(peft_model_id)
-    tokenizer = AutoTokenizer.from_pretrained(
-        peft_model_id,
-        padding_side="right",
-        trust_remote_code=True,
-        pad_token=''
-    )
+peft_model_id = "CMLM/ZhongJing-2-1_8b"
+base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
+model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")
+model.load_adapter(peft_model_id)
+tokenizer = AutoTokenizer.from_pretrained(
+    "CMLM/ZhongJing-2-1_8b",
+    padding_side="right",
+    trust_remote_code=True,
+    pad_token=''
+)
 
-@spaces.GPU
+@spaces.gpu()
 def generate(
     message: str,
-    chat_history: list[tuple[str, str]],
-    system_prompt: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
-    top_p: float = 0.9,
+    top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    prompt = f"Question: {message}"
+    messages = [
+        {"role": "system", "content": "You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来 of Fudan University."},
+        {"role": "user", "content": prompt}
+    ]
+
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    input_ids = tokenizer([text], return_tensors="pt").input_ids
+
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
-        streamer=streamer,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
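Note on the hunk above: two details look wrong in the new version. First, the ZeroGPU decorator in the `spaces` package is `spaces.GPU` (capitalized); `@spaces.gpu()` would raise an AttributeError when the Space starts. Second, the commit drops the old `input_ids = input_ids.to(model.device)` line, so the CPU-side input tensor is no longer moved to whatever device `device_map="auto"` placed the model on. A minimal sketch of the corrected pieces, reusing the commit's model IDs but wrapped in a hypothetical non-streaming `answer` helper rather than the app's `generate`:

    import spaces
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
    peft_model_id = "CMLM/ZhongJing-2-1_8b"
    model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16, device_map="auto")
    model.load_adapter(peft_model_id)  # PEFT adapter on top of the Qwen base; requires `peft`
    tokenizer = AutoTokenizer.from_pretrained(peft_model_id, trust_remote_code=True)

    @spaces.GPU  # capitalized; the package exposes no `gpu` attribute
    def answer(message: str) -> str:
        messages = [{"role": "user", "content": f"Question: {message}"}]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        input_ids = tokenizer([text], return_tensors="pt").input_ids.to(model.device)  # restore the device move
        output_ids = model.generate(input_ids, max_new_tokens=256)
        return tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)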
@@ -83,54 +80,55 @@ def generate(
         outputs.append(text)
         yield "".join(outputs)
 
-chat_interface = gr.ChatInterface(
+chat_interface = gr.Interface(
     fn=generate,
-    additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
-        gr.Slider(
+    inputs=[
+        gr.components.Textbox(label="Enter your question"),
+        gr.components.Slider(
             label="Max new tokens",
-            minimum=1,
+            minimum=1,
             maximum=MAX_MAX_NEW_TOKENS,
             step=1,
             value=DEFAULT_MAX_NEW_TOKENS,
         ),
-        gr.Slider(
+        gr.components.Slider(
             label="Temperature",
             minimum=0.1,
             maximum=4.0,
-            step=0.1,
+            step=0.1,
             value=0.6,
         ),
-        gr.Slider(
+        gr.components.Slider(
             label="Top-p (nucleus sampling)",
             minimum=0.05,
             maximum=1.0,
             step=0.05,
             value=0.9,
         ),
-        gr.Slider(
+        gr.components.Slider(
             label="Top-k",
             minimum=1,
             maximum=1000,
             step=1,
             value=50,
         ),
-        gr.Slider(
-            label="Repetition penalty",
+        gr.components.Slider(
+            label="Repetition penalty",
             minimum=1.0,
             maximum=2.0,
             step=0.05,
             value=1.2,
         ),
     ],
-    stop_btn=None,
+    outputs="text",
+    title="仲景GPT-V2-1.8B",
+    description=DESCRIPTION,
+    allow_flagging=False,
     examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+        ["请问气虚体质有哪些症状表现?"],
+        ["简单介绍一下中医的五行学说。"],
+        ["桑螵蛸是什么?有什么功效作用?"],
     ],
 )
 
 with gr.Blocks(css="style.css") as demo:
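Note on the hunk above: swapping `gr.ChatInterface` for `gr.Interface` makes the demo stateless; `chat_history` is gone from `generate`, so each question is answered in isolation. With `gr.Interface`, the `inputs` components are passed to `fn` positionally, so their order must mirror the function signature (message, max_new_tokens, temperature, top_p, top_k, repetition_penalty), which the Textbox plus five sliders above do; `gr.components.Slider` is the same class as `gr.Slider`. Recent Gradio releases also expect `allow_flagging="never"` rather than the boolean `False`, which older releases merely translate with a deprecation warning. A stripped-down sketch of the positional wiring, with a hypothetical `echo` stand-in for `generate`:

    import gradio as gr

    def echo(message: str, max_new_tokens: int) -> str:
        # stand-in for generate(); the components below arrive positionally, in order
        return f"{message} (max_new_tokens={max_new_tokens})"

    demo = gr.Interface(
        fn=echo,
        inputs=[
            gr.Textbox(label="Enter your question"),  # -> message
            gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),  # -> max_new_tokens
        ],
        outputs="text",
    )

    if __name__ == "__main__":
        demo.launch()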
@@ -140,4 +138,4 @@ with gr.Blocks(css="style.css") as demo:
     gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
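For orientation, the unchanged region elided between the first two hunks (old lines 73-82, new lines 70-79) is where `generate_kwargs` is consumed. In the standard streaming template this app follows, that section runs `model.generate` on a worker thread and drains the `TextIteratorStreamer`; roughly, as a sketch using the names visible in the diff context, wrapped in a hypothetical helper:

    from threading import Thread
    from typing import Iterator

    def stream_completion(model, streamer, generate_kwargs) -> Iterator[str]:
        # Run generation in the background so the streamer can be drained as tokens arrive.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()
        outputs = []
        for text in streamer:          # yields decoded text chunks
            outputs.append(text)
            yield "".join(outputs)     # matches the `outputs.append` / `yield` context lines above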