Josh Nguyen committed
Commit d38f5f1
1 Parent(s): 422252e
Fix a bug in generate_text
app.py CHANGED
@@ -34,12 +34,10 @@ def generate_text(prompt: str,
                   temperature: float = 0.5,
                   top_p: float = 0.95,
                   top_k: int = 50) -> str:
-
     # Encode the prompt
     inputs = tokenizer([prompt],
                        return_tensors='pt',
                        add_special_tokens=False).to(DEVICE)
-
     # Prepare arguments for generation
     input_length = inputs["input_ids"].shape[-1]
     max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
@@ -56,8 +54,8 @@ def generate_text(prompt: str,
         skip_prompt=True,
         skip_special_tokens=True)
     generation_kwargs = dict(
-        inputs
-        streamer=
+        **inputs,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
@@ -65,12 +63,10 @@ def generate_text(prompt: str,
         temperature=temperature,
         num_beams=1,
     )
-
     # Generate text
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-
-    generated_text = ""
+    generated_text = prompt
     for new_text in streamer:
         generated_text += new_text
     return generated_text
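For context, here is a minimal, self-contained sketch of what the patched generate_text plausibly looks like after this commit. The streamer is assumed to be transformers.TextIteratorStreamer (the diff only shows its skip_prompt/skip_special_tokens arguments), and the model checkpoint, DEVICE, WINDOW_SIZE, the max_new_tokens default, and the top_k=top_k line elided between the hunks are illustrative assumptions; only the **inputs / streamer=streamer wiring and the generated_text = prompt initialisation are taken from the diff itself.

# Sketch of the patched generate_text; names marked "assumption" are not in the diff.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_NAME = "gpt2"   # assumption: the Space's actual checkpoint is not shown in the diff
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
WINDOW_SIZE = 1024    # assumption: context window of the assumed checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)


def generate_text(prompt: str,
                  max_new_tokens: int = 256,   # assumption: default not visible in the diff
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50) -> str:
    # Encode the prompt
    inputs = tokenizer([prompt],
                       return_tensors='pt',
                       add_special_tokens=False).to(DEVICE)

    # Prepare arguments for generation: cap new tokens so prompt + output fits the window
    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)

    # Stream decoded text back as it is produced, skipping the echoed prompt
    streamer = TextIteratorStreamer(tokenizer,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,              # the fix: unpack input_ids/attention_mask into the kwargs
        streamer=streamer,     # the fix: actually pass the streamer object
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,           # assumption: this line falls between the diff hunks
        temperature=temperature,
        num_beams=1,
    )

    # Generate text in a background thread while the streamer is consumed here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = prompt    # the fix: return the prompt plus its continuation
    for new_text in streamer:
        generated_text += new_text
    thread.join()              # not in the diff; waits for the worker before returning
    return generated_text


if __name__ == "__main__":
    print(generate_text("Once upon a time"))

Note that generated_text is now seeded with the prompt rather than an empty string, so the function returns the prompt plus the generated continuation even though the streamer itself skips the prompt tokens.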