Update app.py
app.py CHANGED

@@ -15,7 +15,6 @@ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
 # Load the Lora model
 model = PeftModel.from_pretrained(model, peft_model_id)
 
-
 def make_inference(synopsis):
     batch = tokenizer(
         f"Below is a one-sentence synopsis, please write a captivating short story based on this synopsis.\n\n### Synopsis:\n{synopsis}\n\n### Short Story:\n", return_tensors='pt',
@@ -24,7 +23,11 @@ def make_inference(synopsis):
     with torch.cuda.amp.autocast():
         output_tokens = model.generate(**batch, max_new_tokens=200)
 
-
+    full_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+    short_story = full_output.split("### Short Story:\n")[-1].strip()
+
+    return short_story
+
 
 
 if __name__ == "__main__":
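
The added lines address a quirk of decoder-only generation: model.generate returns the prompt tokens followed by the continuation, so the decoded string still contains the instruction and synopsis. Splitting on the "### Short Story:\n" marker and keeping the last segment isolates just the generated story. A minimal standalone sketch of that behavior, using a made-up decoded string rather than a real model output:

# Sketch only: full_output stands in for tokenizer.decode(output_tokens[0],
# skip_special_tokens=True); the story text here is invented for illustration.
full_output = (
    "Below is a one-sentence synopsis, please write a captivating short story "
    "based on this synopsis.\n\n"
    "### Synopsis:\nA lighthouse keeper finds a message in a bottle.\n\n"
    "### Short Story:\nThe bottle glinted in the surf..."
)
# Take everything after the marker and trim surrounding whitespace.
short_story = full_output.split("### Short Story:\n")[-1].strip()
print(short_story)  # -> The bottle glinted in the surf...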
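
The diff truncates at the entry point, so the body of the if __name__ == "__main__": block is not shown. For a Hugging Face Space, a typical wiring (an assumption, not taken from this commit) would pass make_inference to a Gradio interface:

import gradio as gr

# Hypothetical entry point: the actual main block is not visible in this diff.
demo = gr.Interface(
    fn=make_inference,                    # returns the stripped story text
    inputs=gr.Textbox(label="Synopsis"),  # one-sentence synopsis from the user
    outputs=gr.Textbox(label="Short Story"),
)
demo.launch()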