jaydenccc commited on
Commit
8b070e9
·
1 Parent(s): a4439c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -15,7 +15,6 @@ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
15
  # Load the Lora model
16
  model = PeftModel.from_pretrained(model, peft_model_id)
17
 
18
-
19
  def make_inference(synopsis):
20
  batch = tokenizer(
21
  f"Below is a one-sentence synopsis, please write a captivating short story based on this synopsis.\n\n### Synopsis:\n{synopsis}\n\n### Short Story:\n", return_tensors='pt',
@@ -24,7 +23,11 @@ def make_inference(synopsis):
24
  with torch.cuda.amp.autocast():
25
  output_tokens = model.generate(**batch, max_new_tokens=200)
26
 
27
- return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
 
 
 
 
28
 
29
 
30
  if __name__ == "__main__":
 
15
  # Load the Lora model
16
  model = PeftModel.from_pretrained(model, peft_model_id)
17
 
 
18
  def make_inference(synopsis):
19
  batch = tokenizer(
20
  f"Below is a one-sentence synopsis, please write a captivating short story based on this synopsis.\n\n### Synopsis:\n{synopsis}\n\n### Short Story:\n", return_tensors='pt',
 
23
  with torch.cuda.amp.autocast():
24
  output_tokens = model.generate(**batch, max_new_tokens=200)
25
 
26
+ full_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
27
+ short_story = full_output.split("### Short Story:\n")[-1].strip()
28
+
29
+ return short_story
30
+
31
 
32
 
33
  if __name__ == "__main__":