Spaces:
Paused
Paused
Commit
·
8fc450b
1
Parent(s):
9110acb
Update main.py
Browse files
main.py
CHANGED
@@ -23,7 +23,7 @@ tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B", tr
|
|
23 |
def ask_bot(question):
|
24 |
input_ids = tokenizer.encode(question, return_tensors="pt").to(device)
|
25 |
with torch.no_grad():
|
26 |
-
output = model.generate(input_ids, max_length=
|
27 |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
28 |
response = generated_text.split("->:")[-1]
|
29 |
return response
|
@@ -65,7 +65,7 @@ class CustomLLM(LLM):
|
|
65 |
|
66 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
67 |
with torch.no_grad():
|
68 |
-
output = model.generate(input_ids, max_length=
|
69 |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
70 |
response = generated_text.split("->:")[-1]
|
71 |
return response
|
@@ -156,13 +156,13 @@ def chatbot(patient_id, user_data: dict=None):
|
|
156 |
human_input = prompt + user_input + " ->:"
|
157 |
human_text = user_input.replace("'", "")
|
158 |
response = llm._call(human_input)
|
159 |
-
response = response.replace("'", "")
|
160 |
-
memory.save_context({"input": user_input}, {"output": response})
|
161 |
-
summary = memory.load_memory_variables({})
|
162 |
-
ai_text = response.replace("'", "")
|
163 |
-
memory.save_context({"input": user_input}, {"output": ai_text})
|
164 |
-
summary = memory.load_memory_variables({})
|
165 |
-
db.insert(("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"), (patient_id, human_text, ai_text, str(datetime.now()), summary['history'].replace("'", "")))
|
166 |
db.close_db()
|
167 |
return {"response": response}
|
168 |
finally:
|
|
|
23 |
def ask_bot(question):
|
24 |
input_ids = tokenizer.encode(question, return_tensors="pt").to(device)
|
25 |
with torch.no_grad():
|
26 |
+
output = model.generate(input_ids, max_length=100, num_return_sequences=1, do_sample=True, top_k=50)
|
27 |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
28 |
response = generated_text.split("->:")[-1]
|
29 |
return response
|
|
|
65 |
|
66 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
67 |
with torch.no_grad():
|
68 |
+
output = model.generate(input_ids, max_length=100, num_return_sequences=1, do_sample=True, top_k=50)
|
69 |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
70 |
response = generated_text.split("->:")[-1]
|
71 |
return response
|
|
|
156 |
human_input = prompt + user_input + " ->:"
|
157 |
human_text = user_input.replace("'", "")
|
158 |
response = llm._call(human_input)
|
159 |
+
# response = response.replace("'", "")
|
160 |
+
# memory.save_context({"input": user_input}, {"output": response})
|
161 |
+
# summary = memory.load_memory_variables({})
|
162 |
+
# ai_text = response.replace("'", "")
|
163 |
+
# memory.save_context({"input": user_input}, {"output": ai_text})
|
164 |
+
# summary = memory.load_memory_variables({})
|
165 |
+
# db.insert(("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"), (patient_id, human_text, ai_text, str(datetime.now()), summary['history'].replace("'", "")))
|
166 |
db.close_db()
|
167 |
return {"response": response}
|
168 |
finally:
|