vilarin committed on
Commit
3738ef6
1 Parent(s): 39dbb23

Update app.py

Files changed (1)
  1. app.py +116 -78
app.py CHANGED
@@ -1,98 +1,104 @@
- import subprocess
- subprocess.run(
-     'pip install flash-attn --no-build-isolation',
-     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-     shell=True
- )
-
+ import os
+ import time
+ import spaces
  import torch
- from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
  import gradio as gr
- import spaces
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
- import os
  from threading import Thread

-
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
- MODEL_LIST = "THUDM/LongWriter-glm4-9b"
- #MODELS = os.environ.get("MODELS")
- #MODEL_NAME = MODELS.split("/")[-1]
+ MODEL = "LGAI-EXAONE/EXAONE-3.5-32B-Instruct"

- TITLE = "<h1><center>GLM SPACE</center></h1>"
+ TITLE = "<h1><center>EXAONE-3.5-32B-Instruct</center></h1>"
+
+ PLACEHOLDER = """
+ <center>
+ <p>Hi! How can I help you today?</p>
+ </center>
+ """

- PLACEHOLDER = f'<h3><center>LongWriter-glm4-9b is trained based on glm-4-9b, and is capable of generating 10,000+ words at once.</center></h3>'

  CSS = """
  .duplicate-button {
-     margin: auto !important;
-     color: white !important;
-     background: black !important;
-     border-radius: 100vh !important;
+     margin: auto !important;
+     color: white !important;
+     background: black !important;
+     border-radius: 100vh !important;
+ }
+ h3 {
+     text-align: center;
  }
  """

+ device = "cuda" # for GPU usage or "cpu" for CPU usage
+
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type= "nf4")
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
  model = AutoModelForCausalLM.from_pretrained(
-     "THUDM/LongWriter-glm4-9b",
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
-     trust_remote_code=True,
- ).eval()
-
- tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True, use_fast=False)
-
- class StopOnTokens(StoppingCriteria):
-     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-         # stop_ids = model.config.eos_token_id
-         stop_ids = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
-                     tokenizer.get_command("<|observation|>")]
-         for stop_id in stop_ids:
-             if input_ids[0][-1] == stop_id:
-                 return True
-         return False
-
- @spaces.GPU()
- def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
-     print(f'message is - {message}')
-     print(f'history is - {history}')
-     conversation = []
+     MODEL,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     quantization_config=quantization_config)
+
+ @spaces.GPU(duration=100)
+ def stream_chat(
+     message: str,
+     history: list,
+     system_prompt: str,
+     temperature: float = 0.8,
+     max_new_tokens: int = 1024,
+     top_p: float = 1.0,
+     top_k: int = 20,
+     penalty: float = 1.2,
+ ):
+     print(f'message: {message}')
+     print(f'history: {history}')
+
+     conversation = [
+         {"role": "system", "content": system_prompt}
+     ]
      for prompt, answer in history:
-         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
-     #conversation.append({"role": "user", "content": message})
+         conversation.extend([
+             {"role": "user", "content": prompt},
+             {"role": "assistant", "content": answer},
+         ])

-     print(f"Conversation is -\n{conversation}")
-     stop = StopOnTokens()
+     conversation.append({"role": "user", "content": message})

-     input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(next(model.parameters()).device)
-     #input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
+
      streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-     eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
-                     tokenizer.get_command("<|observation|>")]
-
+
      generate_kwargs = dict(
-         input_ids=input_ids,
+         input_ids=input_ids,
+         max_new_tokens = max_new_tokens,
+         do_sample = False if temperature == 0 else True,
+         top_p = top_p,
+         top_k = top_k,
+         temperature = temperature,
+         repetition_penalty=penalty,
+         eos_token_id=255001,
          streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_k=1,
-         temperature=temperature,
-         repetition_penalty=1,
-         stopping_criteria=StoppingCriteriaList([stop]),
-         eos_token_id=eos_token_id,
      )
-     #gen_kwargs = {**input_ids, **generate_kwargs}
-
-     thread = Thread(target=model.generate, kwargs=generate_kwargs)
-     thread.start()
-     buffer = ""
-     for new_token in streamer:
-         if new_token and '<|user|>' not in new_token:
-             buffer += new_token
+
+     with torch.no_grad():
+         thread = Thread(target=model.generate, kwargs=generate_kwargs)
+         thread.start()
+
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
          yield buffer

- chatbot = gr.Chatbot(height=600, placeholder = PLACEHOLDER)
+
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

- with gr.Blocks(css=CSS) as demo:
+ with gr.Blocks(css=CSS, theme="soft") as demo:
      gr.HTML(TITLE)
      gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
      gr.ChatInterface(
@@ -101,20 +107,52 @@ with gr.Blocks(css=CSS) as demo:
          fill_height=True,
          additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
          additional_inputs=[
+             gr.Textbox(
+                 value="""
+                 You are a helpful assistant.
+                 """,
+                 label="System Prompt",
+                 lines=5,
+                 render=False,
+             ),
              gr.Slider(
                  minimum=0,
                  maximum=1,
                  step=0.1,
-                 value=0.5,
+                 value=0.8,
                  label="Temperature",
                  render=False,
              ),
              gr.Slider(
-                 minimum=1024,
-                 maximum=32768,
+                 minimum=128,
+                 maximum=8192,
                  step=1,
-                 value=4096,
-                 label="Max New Tokens",
+                 value=1024,
+                 label="Max new tokens",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 step=0.1,
+                 value=1.0,
+                 label="top_p",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=1,
+                 maximum=20,
+                 step=1,
+                 value=20,
+                 label="top_k",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=2.0,
+                 step=0.1,
+                 value=1.2,
+                 label="Repetition penalty",
                  render=False,
              ),
          ],
@@ -126,7 +164,7 @@ with gr.Blocks(css=CSS) as demo:
          ],
          cache_examples=False,
      )
-
+

  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
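For readers who want to try the core pattern of the new app.py outside of Gradio: model.generate runs in a background Thread while a TextIteratorStreamer yields decoded text chunks as they are produced. Below is a minimal standalone sketch of that pattern; the small model id is only an assumed stand-in so it can run without the Space's 32B 4-bit setup, not what the Space deploys.

# Minimal sketch of the Thread + TextIteratorStreamer streaming pattern used in the new app.py.
# Assumption: "Qwen/Qwen2.5-0.5B-Instruct" is a hypothetical small stand-in chat model;
# the Space itself loads LGAI-EXAONE/EXAONE-3.5-32B-Instruct with BitsAndBytes 4-bit quantization.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # stand-in for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one short sentence."},
]
# Build the prompt with the model's chat template, as the new stream_chat does
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# skip_prompt=True so only newly generated text is streamed back to the caller
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a background thread while we consume the streamer
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, max_new_tokens=64, do_sample=False, streamer=streamer),
)
thread.start()

buffer = ""
for new_text in streamer:  # yields decoded chunks as tokens arrive
    buffer += new_text
print(buffer)
thread.join()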