mdacampora committed
Commit 61da192 · 1 Parent(s): 976e5ed

Update app.py

Files changed (1)
  1. app.py +23 -23
app.py CHANGED
@@ -15,35 +15,35 @@ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
 # Load the Lora model
 model = PeftModel.from_pretrained(model, peft_model_id)
 
-# def make_inference(problem):
-#     batch = tokenizer(
-#         problem,
-#         return_tensors="pt",
-#     )
+def make_inference(problem, answer):
+    batch = tokenizer(
+        problem,
+        return_tensors="pt",
+    )
 
-#     with torch.cuda.amp.autocast():
-#         output_tokens = model.generate(**batch, max_new_tokens=50)
+    with torch.cuda.amp.autocast():
+        output_tokens = model.generate(**batch, max_new_tokens=50)
 
 
 
 
 
-def make_inference(conversation, response):
-    conversation_history = conversation
-    response = ""
-    while True:
-        batch = tokenizer(
-            f"### Problem:\n{conversation_history}\n{response}",
-            return_tensors="pt",
-        )
-        with torch.cuda.amp.autocast():
-            output_tokens = model.generate(**batch, max_new_tokens=50)
-        new_response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
-        if new_response.strip() == "":
-            break
-        response = f"\n{new_response}"
-        conversation_history += response
-    return conversation_history
+# def make_inference(conversation, response):
+#     conversation_history = conversation
+#     response = ""
+#     while True:
+#         batch = tokenizer(
+#             f"### Problem:\n{conversation_history}\n{response}",
+#             return_tensors="pt",
+#         )
+#         with torch.cuda.amp.autocast():
+#             output_tokens = model.generate(**batch, max_new_tokens=50)
+#         new_response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+#         if new_response.strip() == "":
+#             break
+#         response = f"\n{new_response}"
+#         conversation_history += response
+#     return conversation_history
 
 
 if __name__ == "__main__":
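
For context, a minimal self-contained sketch of the inference path as it stands after this commit. The adapter id is a placeholder and the decode/return step is an assumption: the hunk cuts off at model.generate, and the real peft_model_id is defined earlier in app.py, outside this diff.

# Minimal sketch of the post-commit inference path (not the full app.py).
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

peft_model_id = "some-user/some-lora-adapter"  # placeholder: the real id is set earlier in app.py
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Load the LoRA adapter weights on top of the base model
model = PeftModel.from_pretrained(model, peft_model_id)

def make_inference(problem, answer):
    # `answer` is accepted but unused in the committed code as shown
    batch = tokenizer(problem, return_tensors="pt")
    # Mixed-precision generation, as in the commit
    with torch.cuda.amp.autocast():
        output_tokens = model.generate(**batch, max_new_tokens=50)
    # Assumed completion: decode the generated ids back to text
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

A call like make_inference("### Problem:\n2 + 2 = ?", answer="") would then return the decoded generation for that prompt.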