Update app.py
Speed up with caching.
app.py CHANGED
@@ -50,6 +50,7 @@ if model_big.device == "cuda":
 if model_small.device == "cuda":
     model_small = torch.compile(model_small)
 
+@torch.no_grad()
 @spaces.GPU
 def stream_chat(
     message: str,
@@ -73,20 +74,22 @@ def stream_chat(
 
     conversation.append({"role": "user", "content": message})
 
-    (removed line; content not captured in this view)
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    inputs = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt")
 
     generated_tokens = []
     current_input = inputs
+
+    cache_small = None
+    cache_big = None
 
     for _ in range(max_new_tokens):
-        (five removed lines; content not captured in this view)
+        outputs_small = model_small(current_input, use_cache=True, past_key_values=cache_small)
+        outputs_big = model_big(current_input, use_cache=True, past_key_values=cache_big)
+
+        logits_small = outputs_small.logits[:, -1, :]
+        logits_big = outputs_big.logits[:, -1, :]
 
-        interpolated_logits = logits_big + (guidance_scale - 1) * (logits_big - logits_small)
+        interpolated_logits = logits_big + (guidance_scale - 1) * (logits_big - logits_small)
 
         if top_p < 1.0:
             interpolated_logits = top_p_filtering(interpolated_logits, top_p=top_p)
@@ -99,7 +102,11 @@ def stream_chat(
             break
 
         generated_tokens.append(next_token.item())
-        current_input = (rest of line not captured in this view)
+        current_input = next_token
+
+        # Update the cache with the latest past_key_values
+        cache_small = outputs_small.past_key_values
+        cache_big = outputs_big.past_key_values
 
        partial_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        yield partial_output
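For context, a minimal, self-contained sketch of the pattern this commit applies: each step passes use_cache=True together with the previous past_key_values, so after the prompt only the single newest token is fed back into the two models, and their last-position logits are interpolated with the guidance scale exactly as in the diff. The checkpoints (gpt2, gpt2-medium), greedy selection, the helper name generate_with_cache, and the prompt are illustrative assumptions, not the Space's actual configuration (the Space samples with top-p and streams partial output).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: any pair of causal LMs that share a tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model_small = AutoModelForCausalLM.from_pretrained("gpt2").eval()
model_big = AutoModelForCausalLM.from_pretrained("gpt2-medium").eval()

@torch.no_grad()
def generate_with_cache(prompt, max_new_tokens=32, guidance_scale=1.5):
    current_input = tokenizer(prompt, return_tensors="pt").input_ids
    cache_small, cache_big = None, None
    generated = []

    for _ in range(max_new_tokens):
        # With use_cache=True each call returns past_key_values, so the next
        # step only has to run the forward pass over the newest token.
        out_small = model_small(current_input, use_cache=True, past_key_values=cache_small)
        out_big = model_big(current_input, use_cache=True, past_key_values=cache_big)

        logits_small = out_small.logits[:, -1, :]
        logits_big = out_big.logits[:, -1, :]

        # Same interpolation as in the diff: push the big model's distribution
        # away from the small model's by the guidance scale.
        logits = logits_big + (guidance_scale - 1) * (logits_big - logits_small)

        next_token = logits.argmax(dim=-1, keepdim=True)  # greedy here; the Space uses top-p sampling
        if next_token.item() == tokenizer.eos_token_id:
            break

        generated.append(next_token.item())
        current_input = next_token               # feed only the new token next step
        cache_small = out_small.past_key_values  # carry the KV caches forward
        cache_big = out_big.past_key_values

    return tokenizer.decode(generated, skip_special_tokens=True)

print(generate_with_cache("The capital of France is"))

Carrying the caches forward means each step processes one token instead of re-encoding the whole growing sequence, which is where the speedup named in the commit message comes from.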