Azazelle committed on
Commit bb221fb · verified · 1 Parent(s): ec52c4d

Update app.py


Speed up with caching.
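The speed-up comes from reusing the transformers key/value cache during generation: the prompt is encoded once, and every later step feeds only the newest token together with past_key_values instead of re-running both models over the whole sequence. A minimal single-model sketch of that pattern, assuming a Hugging Face causal LM (the "gpt2" checkpoint and the greedy_generate helper below are illustrative, not from app.py):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; app.py uses its own model_small / model_big pair.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

@torch.no_grad()
def greedy_generate(prompt: str, max_new_tokens: int = 32) -> str:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated = []
    cache = None               # past_key_values, populated after the first step
    current_input = input_ids  # the full prompt is encoded only once
    for _ in range(max_new_tokens):
        out = model(current_input, use_cache=True, past_key_values=cache)
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated.append(next_token.item())
        current_input = next_token   # later steps feed only the newest token
        cache = out.past_key_values  # cached keys/values already cover the prefix
    return tokenizer.decode(generated, skip_special_tokens=True)
```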

Files changed (1)
  1. app.py +16 -9
app.py CHANGED
@@ -50,6 +50,7 @@ if model_big.device == "cuda":
 if model_small.device == "cuda":
     model_small = torch.compile(model_small)
 
+@torch.no_grad()
 @spaces.GPU
 def stream_chat(
     message: str,
@@ -73,20 +74,22 @@ def stream_chat(
 
     conversation.append({"role": "user", "content": message})
 
-    input_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    inputs = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt")
 
     generated_tokens = []
     current_input = inputs
+
+    cache_small = None
+    cache_big = None
 
     for _ in range(max_new_tokens):
-        with torch.no_grad():
-            logits_small = model_small(current_input).logits[:, -1, :]
-            logits_big = model_big(current_input).logits[:, -1, :]
-
-            probs_small = torch.softmax(logits_small / temperature, dim=-1)
+        outputs_small = model_small(current_input, use_cache=True, past_key_values=cache_small)
+        outputs_big = model_big(current_input, use_cache=True, past_key_values=cache_big)
+
+        logits_small = outputs_small.logits[:, -1, :]
+        logits_big = outputs_big.logits[:, -1, :]
 
-        interpolated_logits = logits_big + (guidance_scale - 1) * (logits_big - logits_small) * probs_small
+        interpolated_logits = logits_big + (guidance_scale - 1) * (logits_big - logits_small)
 
         if top_p < 1.0:
             interpolated_logits = top_p_filtering(interpolated_logits, top_p=top_p)
@@ -99,7 +102,11 @@ def stream_chat(
             break
 
         generated_tokens.append(next_token.item())
-        current_input = torch.cat([current_input, next_token], dim=1)
+        current_input = next_token
+
+        # Update the cache with the latest past_key_values
+        cache_small = outputs_small.past_key_values
+        cache_big = outputs_big.past_key_values
 
         partial_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
         yield partial_output
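The top_p_filtering helper called in the loop is defined elsewhere in app.py and is not part of this diff. For reference, a common nucleus-filtering implementation looks roughly like the sketch below; this is an assumption about what such a helper typically does, not the Space's actual code:

```python
import torch

def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9,
                    filter_value: float = float("-inf")) -> torch.Tensor:
    """Nucleus filtering: keep the smallest set of tokens whose cumulative
    probability exceeds top_p and mask everything else. Illustrative only."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Mark tokens that fall outside the nucleus...
    sorted_mask = cumulative_probs > top_p
    # ...shifted right so the token that crosses the threshold is kept,
    # and the single most probable token is never masked.
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    # Map the mask back from sorted order to vocabulary order.
    mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
    return logits.masked_fill(mask, filter_value)
```

In the loop above, such a filter is applied to interpolated_logits before the next token is drawn, so low-probability tokens outside the nucleus cannot be sampled.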