Azazelle committed
Commit 105395c · verified · 1 Parent(s): 5982f31

Update app.py

Files changed (1)
  1. app.py +69 -42
app.py CHANGED
@@ -6,22 +6,16 @@ import torch

 print(torch.__version__)

-# install packages for mamba
-def install_mamba():
-    subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
-    subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
-
-install_mamba()
-
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
 from threading import Thread

-MODEL = "tiiuae/falcon-mamba-7b-instruct"
+MODEL_BIG = "HuggingFaceTB/SmolLM-1.7B-Instruct"
+MODEL_SMALL = "HuggingFaceTB/SmolLM-360M-Instruct"

-TITLE = "<h1><center>FalconMamba-7b playground</center></h1>"
-SUB_TITLE = """<center>FalconMamba is a new model released by Technology Innovation Institute (TII) in Abu Dhabi. The model is open source and available within the Hugging Face ecosystem for anyone to use it for their research or application purpose. Refer to <a href="https://hf.co/blog/falconmamba">the HF release blogpost</a> or <a href="https://www.tii.ae/news/uaes-technology-innovation-institute-revolutionizes-ai-language-models-new-architecture">the official announcement</a> for more details. This interface has been created for quick validation purposes, do not use it for production.</center>"""
+TITLE = "<h1><center>Auto-Guidance Playground</center></h1>"
+SUB_TITLE = """<center>Auto-guidance is a technique introduced by NVIDIA for text-conditioned image models. This is a test of the concept with SmolLM.</center>"""

 CSS = """
 .duplicate-button {
@@ -40,19 +34,21 @@ END_MESSAGE = """
 **The conversation has reached its end, please press "Clear" to restart a new conversation**
 """

-device = "cuda" # for GPU usage or "cpu" for CPU usage
+device = "cpu" # for GPU usage or "cpu" for CPU usage

-tokenizer = AutoTokenizer.from_pretrained(MODEL)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL,
+tokenizer = AutoTokenizer.from_pretrained(MODEL_BIG)
+model_big = AutoModelForCausalLM.from_pretrained(
+    MODEL_BIG,
+    torch_dtype=torch.bfloat16,
+).to(device)
+model_small = AutoModelForCausalLM.from_pretrained(
+    MODEL_SMALL,
     torch_dtype=torch.bfloat16,
 ).to(device)

 if device == "cuda":
     model = torch.compile(model)

-
-
 @spaces.GPU
 def stream_chat(
     message: str,
@@ -62,6 +58,7 @@ def stream_chat(
     top_p: float = 1.0,
     top_k: int = 20,
     penalty: float = 1.2,
+    guidance_scale: float = 1.5,
 ):
     print(f'message: {message}')
     print(f'history: {history}')
@@ -73,37 +70,59 @@ def stream_chat(
             {"role": "assistant", "content": answer},
         ])

-
     conversation.append({"role": "user", "content": message})

-
-    input_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt = True)
-
+    input_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

-    generate_kwargs = dict(
-        input_ids=inputs,
-        max_new_tokens = max_new_tokens,
-        do_sample = False if temperature == 0 else True,
-        top_p = top_p,
-        top_k = top_k,
-        temperature = temperature,
-        streamer=streamer,
-        pad_token_id = 10,
-    )
-
-    with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        yield buffer
-
-
-    print(f'response: {buffer}')
+    generated_tokens = []
+    current_input = inputs
+
+    for _ in range(max_new_tokens):
+        with torch.no_grad():
+            logits_small = model_small(current_input).logits[:, -1, :]
+            logits_big = model_big(current_input).logits[:, -1, :]
+
+        probs_small = torch.softmax(logits_small / temperature, dim=-1)
+        probs_big = torch.softmax(logits_big / temperature, dim=-1)
+
+        interpolated_logits = logits_big + (guidance_scale - 1) * (logits_big - logits_small) * probs_small
+
+        if top_p < 1.0:
+            interpolated_logits = top_p_filtering(interpolated_logits, top_p=top_p)
+        if top_k > 0:
+            interpolated_logits = top_k_filtering(interpolated_logits, top_k=top_k)
+
+        next_token = torch.multinomial(torch.softmax(interpolated_logits, dim=-1), num_samples=1)
+
+        if next_token.item() == tokenizer.eos_token_id:
+            break
+
+        generated_tokens.append(next_token.item())
+        current_input = torch.cat([current_input, next_token], dim=1)
+
+        partial_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        yield partial_output
+
+    print(f'response: {partial_output}')
+
+def top_k_filtering(logits, top_k=0, filter_value=-float('Inf')):
+    top_k = min(top_k, logits.size(-1))
+    if top_k > 0:
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+    return logits
+
+def top_p_filtering(logits, top_p=0.0, filter_value=-float('Inf')):
+    if top_p > 0.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        logits[indices_to_remove] = filter_value
+    return logits

 chatbot = gr.Chatbot(height=600)

@@ -157,6 +176,14 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
                 label="Repetition penalty",
                 render=False,
             ),
+            gr.Slider(
+                minimum=0.0,
+                maximum=10.0,
+                step=0.1,
+                value=1.5,
+                label="Auto-Guidance Scale",
+                render=False,
+            ),
         ],
         examples=[
             ["Hello there, can you suggest a few places to visit in the UAE?"],