mdacampora committed on
Commit
c8e23a7
·
1 Parent(s): 498a9d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -23
app.py CHANGED
@@ -15,35 +15,35 @@ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
15
  # Load the Lora model
16
  model = PeftModel.from_pretrained(model, peft_model_id)
17
 
18
- def make_inference(problem):
19
- batch = tokenizer(
20
- problem,
21
- return_tensors="pt",
22
- )
23
 
24
- with torch.cuda.amp.autocast():
25
- output_tokens = model.generate(**batch, max_new_tokens=50)
26
 
27
 
28
 
29
 
30
 
31
- # def make_inference(conversation):
32
- # conversation_history = conversation
33
- # response = ""
34
- # while True:
35
- # batch = tokenizer(
36
- # f"### Problem:\n{conversation_history}\n{response}",
37
- # return_tensors="pt",
38
- # )
39
- # with torch.cuda.amp.autocast():
40
- # output_tokens = model.generate(**batch, max_new_tokens=50)
41
- # new_response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
42
- # if new_response.strip() == "":
43
- # break
44
- # response = f"\n{new_response}"
45
- # conversation_history += response
46
- # return conversation_history
47
 
48
 
49
  if __name__ == "__main__":
 
15
  # Load the Lora model
16
  model = PeftModel.from_pretrained(model, peft_model_id)
17
 
18
+ # def make_inference(problem):
19
+ # batch = tokenizer(
20
+ # problem,
21
+ # return_tensors="pt",
22
+ # )
23
 
24
+ # with torch.cuda.amp.autocast():
25
+ # output_tokens = model.generate(**batch, max_new_tokens=50)
26
 
27
 
28
 
29
 
30
 
31
def make_inference(conversation, max_rounds=10):
    """Iteratively extend *conversation* with model-generated text.

    Repeatedly prompts the (module-level) PEFT model with the running
    conversation and appends each non-empty generated continuation, until
    the model produces an empty continuation or ``max_rounds`` is reached.

    Args:
        conversation: Initial conversation/problem text.
        max_rounds: Safety cap on generation rounds (new, defaults to 10,
            backward-compatible). The original ``while True`` loop could
            never terminate — see the bug-fix note below.

    Returns:
        The conversation text with all generated responses appended.
    """
    conversation_history = conversation
    response = ""
    for _ in range(max_rounds):
        batch = tokenizer(
            f"### Problem:\n{conversation_history}\n{response}",
            return_tensors="pt",
        )
        # no_grad: pure inference, no autograd graph needed.
        with torch.no_grad(), torch.cuda.amp.autocast():
            output_tokens = model.generate(**batch, max_new_tokens=50)
        # BUG FIX: the original decoded output_tokens[0] in full, which
        # includes the prompt, so the emptiness check below could never
        # succeed and the loop ran forever. Decode only the tokens that
        # generate() appended after the prompt.
        prompt_len = batch["input_ids"].shape[-1]
        new_tokens = output_tokens[0][prompt_len:]
        new_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        if new_response.strip() == "":
            break
        response = f"\n{new_response}"
        conversation_history += response
    return conversation_history
47
 
48
 
49
  if __name__ == "__main__":