Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -66,7 +66,8 @@ def generate(
|
|
66 |
generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
|
67 |
max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
|
68 |
llm_result = llm.generate(**generation_kwargs)
|
69 |
-
llm_result =
|
|
|
70 |
print(llm_result)
|
71 |
expanded_description = json.loads(llm_result)["expanded_description"]
|
72 |
print(expanded_description)
|
|
|
66 |
generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
|
67 |
max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
|
68 |
llm_result = llm.generate(**generation_kwargs)
|
69 |
+
llm_result = llm_result.cpu()[0][len(input_ids):]
|
70 |
+
llm_result = BOT_PREFIX + tokenizer.decode(llm_result, skip_special_tokens=True)
|
71 |
print(llm_result)
|
72 |
expanded_description = json.loads(llm_result)["expanded_description"]
|
73 |
print(expanded_description)
|