Spaces:
Runtime error
Runtime error
[update]edit main
Browse files
main.py
CHANGED
@@ -78,6 +78,9 @@ def main():
|
|
78 |
return_tensors="pt",
|
79 |
add_special_tokens=False,
|
80 |
).input_ids.to(args.device)
|
|
|
|
|
|
|
81 |
|
82 |
with torch.no_grad():
|
83 |
outputs = model.generate(
|
|
|
78 |
return_tensors="pt",
|
79 |
add_special_tokens=False,
|
80 |
).input_ids.to(args.device)
|
81 |
+
bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
|
82 |
+
eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
|
83 |
+
input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
|
84 |
|
85 |
with torch.no_grad():
|
86 |
outputs = model.generate(
|