Spaces:
Sleeping
Sleeping
ASledziewska
commited on
Commit
·
1872b66
1
Parent(s):
5038c7a
Update llm_response_generator.py
Browse files- llm_response_generator.py +29 -18
llm_response_generator.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import os
|
2 |
from langchain_community.llms import HuggingFaceHub
|
3 |
from langchain_community.llms import OpenAI
|
4 |
-
# from langchain.llms import HuggingFaceHub, OpenAI
|
5 |
from langchain.chains import LLMChain
|
6 |
from langchain.prompts import PromptTemplate
|
7 |
import warnings
|
@@ -9,17 +8,19 @@ import warnings
|
|
9 |
warnings.filterwarnings("ignore")
|
10 |
|
11 |
class LLLResponseGenerator():
|
12 |
-
|
13 |
def __init__(self):
|
14 |
-
|
|
|
15 |
|
|
|
|
|
|
|
16 |
|
17 |
def llm_inference(
|
18 |
self,
|
19 |
model_type: str,
|
20 |
question: str,
|
21 |
prompt_template: str,
|
22 |
-
context: str,
|
23 |
ai_tone: str,
|
24 |
questionnaire: str,
|
25 |
user_text: str,
|
@@ -37,7 +38,6 @@ class LLLResponseGenerator():
|
|
37 |
model_str: Denotes the LLM vendor's name. Can be either 'huggingface' or 'openai'
|
38 |
question: The question to be asked to the LLM.
|
39 |
prompt_template: The prompt template itself.
|
40 |
-
context: Instructions for the LLM.
|
41 |
ai_tone: Can be either empathy, encouragement or suggest medical help.
|
42 |
questionnaire: Can be either depression, anxiety or adhd.
|
43 |
user_text: Response given by the user.
|
@@ -65,13 +65,12 @@ class LLLResponseGenerator():
|
|
65 |
)
|
66 |
|
67 |
if model_type == "openai":
|
68 |
-
# https://api.python.langchain.com/en/stable/llms/langchain.llms.openai.OpenAI.html#langchain.llms.openai.OpenAI
|
69 |
llm = OpenAI(
|
70 |
model_name=openai_model_name, temperature=temperature, max_tokens=max_length
|
71 |
)
|
72 |
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
73 |
return llm_chain.run(
|
74 |
-
context=context,
|
75 |
ai_tone=ai_tone,
|
76 |
questionnaire=questionnaire,
|
77 |
question=question,
|
@@ -79,15 +78,14 @@ class LLLResponseGenerator():
|
|
79 |
)
|
80 |
|
81 |
elif model_type == "huggingface":
|
82 |
-
# https://python.langchain.com/docs/integrations/llms/huggingface_hub
|
83 |
llm = HuggingFaceHub(
|
84 |
repo_id=hf_repo_id,
|
85 |
model_kwargs={"temperature": temperature, "max_length": max_length},
|
86 |
)
|
87 |
|
88 |
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
89 |
-
response =
|
90 |
-
context=context,
|
91 |
ai_tone=ai_tone,
|
92 |
questionnaire=questionnaire,
|
93 |
question=question,
|
@@ -108,8 +106,6 @@ if __name__ == "__main__":
|
|
108 |
# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN' and 'OPENAI_API_KEY' values.
|
109 |
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
|
110 |
|
111 |
-
context = "You are a mental health supporting non-medical assistant. DO NOT PROVIDE any medical advice with conviction."
|
112 |
-
|
113 |
ai_tone = "EMPATHY"
|
114 |
questionnaire = "ADHD"
|
115 |
question = (
|
@@ -136,17 +132,32 @@ if __name__ == "__main__":
|
|
136 |
|
137 |
model = LLLResponseGenerator()
|
138 |
|
139 |
-
|
140 |
-
|
141 |
model_type="huggingface",
|
142 |
question=question,
|
143 |
prompt_template=template,
|
144 |
-
context=context,
|
145 |
ai_tone=ai_tone,
|
146 |
questionnaire=questionnaire,
|
147 |
user_text=user_text,
|
148 |
temperature=temperature,
|
149 |
max_length=max_length,
|
150 |
-
)
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
from langchain_community.llms import HuggingFaceHub
|
3 |
from langchain_community.llms import OpenAI
|
|
|
4 |
from langchain.chains import LLMChain
|
5 |
from langchain.prompts import PromptTemplate
|
6 |
import warnings
|
|
|
8 |
warnings.filterwarnings("ignore")
|
9 |
|
10 |
class LLLResponseGenerator():
|
|
|
11 |
def __init__(self):
|
12 |
+
self.context = "You are a mental health supporting non-medical assistant. DO NOT PROVIDE any medical advice with conviction."
|
13 |
+
self.conversation_history = []
|
14 |
|
15 |
+
def update_context(self, user_text):
|
16 |
+
self.conversation_history.append(user_text)
|
17 |
+
self.context = "\n".join(self.conversation_history)
|
18 |
|
19 |
def llm_inference(
|
20 |
self,
|
21 |
model_type: str,
|
22 |
question: str,
|
23 |
prompt_template: str,
|
|
|
24 |
ai_tone: str,
|
25 |
questionnaire: str,
|
26 |
user_text: str,
|
|
|
38 |
model_str: Denotes the LLM vendor's name. Can be either 'huggingface' or 'openai'
|
39 |
question: The question to be asked to the LLM.
|
40 |
prompt_template: The prompt template itself.
|
|
|
41 |
ai_tone: Can be either empathy, encouragement or suggest medical help.
|
42 |
questionnaire: Can be either depression, anxiety or adhd.
|
43 |
user_text: Response given by the user.
|
|
|
65 |
)
|
66 |
|
67 |
if model_type == "openai":
|
|
|
68 |
llm = OpenAI(
|
69 |
model_name=openai_model_name, temperature=temperature, max_tokens=max_length
|
70 |
)
|
71 |
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
72 |
return llm_chain.run(
|
73 |
+
context=self.context,
|
74 |
ai_tone=ai_tone,
|
75 |
questionnaire=questionnaire,
|
76 |
question=question,
|
|
|
78 |
)
|
79 |
|
80 |
elif model_type == "huggingface":
|
|
|
81 |
llm = HuggingFaceHub(
|
82 |
repo_id=hf_repo_id,
|
83 |
model_kwargs={"temperature": temperature, "max_length": max_length},
|
84 |
)
|
85 |
|
86 |
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
87 |
+
response = llm_chain.run(
|
88 |
+
context=self.context,
|
89 |
ai_tone=ai_tone,
|
90 |
questionnaire=questionnaire,
|
91 |
question=question,
|
|
|
106 |
# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN' and 'OPENAI_API_KEY' values.
|
107 |
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
|
108 |
|
|
|
|
|
109 |
ai_tone = "EMPATHY"
|
110 |
questionnaire = "ADHD"
|
111 |
question = (
|
|
|
132 |
|
133 |
model = LLLResponseGenerator()
|
134 |
|
135 |
+
# Initial prompt
|
136 |
+
print("Bot:", model.llm_inference(
|
137 |
model_type="huggingface",
|
138 |
question=question,
|
139 |
prompt_template=template,
|
|
|
140 |
ai_tone=ai_tone,
|
141 |
questionnaire=questionnaire,
|
142 |
user_text=user_text,
|
143 |
temperature=temperature,
|
144 |
max_length=max_length,
|
145 |
+
))
|
146 |
+
|
147 |
+
while True:
|
148 |
+
user_input = input("You: ")
|
149 |
+
if user_input.lower() == "exit":
|
150 |
+
break
|
151 |
+
|
152 |
+
model.update_context(user_input)
|
153 |
+
|
154 |
+
print("Bot:", model.llm_inference(
|
155 |
+
model_type="huggingface",
|
156 |
+
question=question,
|
157 |
+
prompt_template=template,
|
158 |
+
ai_tone=ai_tone,
|
159 |
+
questionnaire=questionnaire,
|
160 |
+
user_text=user_input,
|
161 |
+
temperature=temperature,
|
162 |
+
max_length=max_length,
|
163 |
+
))
|