Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -14,12 +14,12 @@ model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1",
|
|
14 |
|
15 |
|
16 |
def generate(
|
17 |
-
prompt, history, max_new_tokens=
|
18 |
):
|
19 |
|
20 |
input_text = f"{prompt}, {history}"
|
21 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
22 |
-
outputs = model.generate(input_ids, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, temperature=temperature, top_p=top_p, top_k=top_k)
|
23 |
better_prompt = tokenizer.decode(outputs[0])
|
24 |
return better_prompt
|
25 |
|
|
|
14 |
|
15 |
|
16 |
def generate(
|
17 |
+
prompt, history, max_new_tokens=512, repetition_penalty=1.2, temperature=0.5, top_p=1, top_k=1
|
18 |
):
|
19 |
|
20 |
input_text = f"{prompt}, {history}"
|
21 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
22 |
+
outputs = model.generate(input_ids, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
|
23 |
better_prompt = tokenizer.decode(outputs[0])
|
24 |
return better_prompt
|
25 |
|