mdacampora committed
Commit acb3080
1 Parent(s): 67a414c

Update app.py

Files changed (1):
  1. app.py +17 -12
app.py CHANGED
@@ -17,17 +17,22 @@ model = PeftModel.from_pretrained(model, peft_model_id)
 
 
 
-def make_inference(conversations):
-    context = ""
-    for conversation in conversations:
-        context += f"{conversation}\n\n"
-    prompt = f"### Conversation:\n{context}"
-    batch = tokenizer(prompt, return_tensors="pt")
-    with torch.cuda.amp.autocast():
-        output_tokens = model.generate(**batch, max_new_tokens=50)
-    response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
-    updated_conversation = f"{context}\n{response}"
-    return updated_conversation
+def make_inference(conversation):
+    conversation_history = conversation
+    response = ""
+    while True:
+        batch = tokenizer(
+            f"### Problem:\n{conversation_history}\n{response}",
+            return_tensors="pt",
+        )
+        with torch.cuda.amp.autocast():
+            output_tokens = model.generate(**batch, max_new_tokens=50)
+        new_response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+        if new_response.strip() == "":
+            break
+        response = f"\n{new_response}"
+        conversation_history += response
+    return conversation_history
 
 
 if __name__ == "__main__":
@@ -52,5 +57,5 @@ gr.Interface(
     ],
     gr.outputs.Textbox(label="Updated Conversation"),
     title="tax-convos-demo",
-    description="Trying to create a crude chat bot for tax services.",
+    description="Ask any tax-related questions you may have.",
 ).launch()
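
For reference, a minimal self-contained sketch of the generation loop introduced by this commit: the function keeps calling the model until it returns an empty reply, appending each reply to the running conversation. The fake_generate stub and its canned replies below are placeholders for the real tokenizer(...), model.generate(...), and tokenizer.decode(...) calls in app.py; they are not taken from this diff.

# Sketch only: fake_generate stands in for the tokenizer/model/decode pipeline,
# and the canned replies are illustrative placeholders, not model output.
canned_replies = iter(["Sure, I can help with that. Could you tell me more about your filing situation?", ""])

def fake_generate(prompt):
    # Stand-in for: tokenizer(prompt) -> model.generate(...) -> tokenizer.decode(...)
    return next(canned_replies)

def make_inference(conversation):
    conversation_history = conversation
    response = ""
    while True:
        new_response = fake_generate(f"### Problem:\n{conversation_history}\n{response}")
        if new_response.strip() == "":
            break  # stop once the model produces an empty reply
        response = f"\n{new_response}"
        conversation_history += response
    return conversation_history

if __name__ == "__main__":
    print(make_inference("Hi, I have a question about filing my taxes."))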