Commit · 0782370
Parent(s): 105856f
Add initial implementation of GenAI and related components for mental health assistance
Files changed:
- .env +1 -0
- app.py +93 -0
- requirements.txt +101 -0
- src/genai.py +69 -0
- src/semantic_searcher.py +41 -0
- src/upvote_predictor.py +71 -0
.env
ADDED
@@ -0,0 +1 @@
+GOOGLE_API_KEY = 'AIzaSyCuB81K0_G9wYnUkX9QF20bFSD7fEnTJ6k'
app.py
ADDED
@@ -0,0 +1,93 @@
+import copy
+
+import gradio as gr
+import pandas as pd
+from datasets import load_dataset
+
+from src.genai import GenAI
+
+# from src.semantic_searcher import SemanticSearcher
+# from src.upvote_predictor import UpvotePredictor
+
+# Load the counsel-chat dataset.
+dataset_counsel_chat = load_dataset("nbertagnolli/counsel-chat")
+df_counsel_chat = pd.DataFrame(dataset_counsel_chat["train"])
+df_counsel_chat_topic = copy.deepcopy(
+    df_counsel_chat[
+        ["questionID", "questionTitle", "questionText", "answerText", "topic"]
+    ]
+)
+df_counsel_chat_topic["questionCombined"] = df_counsel_chat_topic.apply(
+    lambda x: (
+        f"QUESTION_TITLE: {x['questionTitle']}\nQUESTION_CONTEXT: {x['questionText']}"
+    ),
+    axis=1,
+)
+df_counsel_chat_topic = df_counsel_chat_topic.drop_duplicates(
+    subset="questionID"
+).reset_index(drop=True)
+# Numbered list of the unique topics, used in the prompt.
+unique_topics = sorted(df_counsel_chat_topic["topic"].unique().tolist())
+unique_topics = "\n".join(
+    [f"{idx + 1}. {topic}" for idx, topic in enumerate(unique_topics)]
+)
+
+# One sampled question per topic, used as few-shot examples in the prompt.
+few_examples = (
+    df_counsel_chat_topic.groupby("topic", as_index=False)[
+        ["questionID", "questionCombined", "answerText", "topic"]
+    ]
+    .apply(lambda s: s.sample(1))
+    .reset_index(drop=True)
+)
+few_examples["examples"] = few_examples.apply(
+    lambda x: (
+        f"{x['questionCombined']}\nTOPIC: {x['topic']}\nANSWER: {x['answerText']}"
+    ),
+    axis=1,
+)
+examples = "\n".join(
+    f"<EXAMPLE {idx + 1} start>\n{example}\n<EXAMPLE {idx + 1} end>\n\n"
+    for idx, example in enumerate(few_examples["examples"].to_list())
+)
+
+# Initialize the GenAI client. The semantic searcher and upvote predictor are
+# stubbed out until their model artifacts are available.
+genai = GenAI()
+# semantic_searcher = SemanticSearcher(df_counsel_chat, df_counsel_chat_topic)
+# upvote_predictor = UpvotePredictor("models/bert_model", semantic_searcher)
+
+
+def get_output(question: str, question_context: str = None) -> tuple:
+    answer, topic = genai.generate_content(
+        question, question_context, unique_topics, examples
+    )
+    # Credibility checking is stubbed until the upvote predictor is wired in:
+    # upvote_prediction = upvote_predictor.get_upvote_prediction(
+    #     question, answer, question_context
+    # )
+    # return (answer, topic, upvote_prediction[0], upvote_prediction[1])
+    return (answer, topic, "Yes", pd.DataFrame())
+
+
+demo = gr.Interface(
+    fn=get_output,
+    inputs=[
+        gr.Textbox(label="Input Question"),
+        gr.Textbox(label="(Optional) Additional Context for Question"),
+    ],
+    outputs=[
+        gr.Textbox(label="GenAI based suggestion"),
+        gr.Textbox(label="Suggested Topic of Question"),
+        gr.Textbox(label="Is GenAI based suggestion credible?"),
+        gr.Dataframe(
+            label=(
+                "Semantically similar questions (and other metadata) to input question."
+                " Will be available if GenAI based suggestion is not credible."
+            )
+        ),
+    ],
+)
+
+demo.launch(debug=True)
+
+# Example input for manual testing:
+# input_question_context = "I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here. I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it. How can I change my feeling of being worthless to everyone?"
+# input_question = "How can I change my feeling of being worthless to everyone?"
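For reference, a minimal smoke test of the handler above; a sketch only, assuming it is placed in app.py before demo.launch(debug=True), reusing the example question from the trailing comments:

answer, topic, credible, similar_df = get_output(
    "How can I change my feeling of being worthless to everyone?",
    None,  # optional question context
)
assert credible == "Yes"   # hard-coded until the upvote predictor is wired in
assert similar_df.empty    # the fallback table is empty for credible answers
print(topic)
print(answer)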
requirements.txt
ADDED
@@ -0,0 +1,101 @@
+-i https://pypi.org/simple
+aiofiles==23.2.1; python_version >= '3.7'
+aiohappyeyeballs==2.6.1; python_version >= '3.9'
+aiohttp==3.11.13; python_version >= '3.9'
+aiosignal==1.3.2; python_version >= '3.9'
+annotated-types==0.7.0; python_version >= '3.8'
+anyio==4.8.0; python_version >= '3.9'
+attrs==25.3.0; python_version >= '3.8'
+cachetools==5.5.2; python_version >= '3.7'
+certifi==2025.1.31; python_version >= '3.6'
+charset-normalizer==3.4.1; python_version >= '3.7'
+click==8.1.8; python_version >= '3.7'
+datasets==3.4.0; python_full_version >= '3.9.0'
+dill==0.3.8; python_version >= '3.8'
+distro==1.9.0; python_version >= '3.6'
+fastapi==0.115.11; python_version >= '3.8'
+ffmpy==0.5.0; python_version >= '3.8' and python_version < '4.0'
+filelock==3.18.0; python_version >= '3.9'
+frozenlist==1.5.0; python_version >= '3.8'
+fsspec[http]==2024.12.0; python_version >= '3.8'
+google-ai-generativelanguage==0.6.15; python_version >= '3.7'
+google-api-core[grpc]==2.24.2; python_version >= '3.7'
+google-api-python-client==2.164.0; python_version >= '3.7'
+google-auth==2.38.0; python_version >= '3.7'
+google-auth-httplib2==0.2.0
+google-generativeai==0.8.4; python_version >= '3.9'
+googleapis-common-protos==1.69.1; python_version >= '3.7'
+gradio==5.21.0; python_version >= '3.10'
+gradio-client==1.7.2; python_version >= '3.10'
+groovy==0.1.2; python_version >= '3.10'
+grpcio==1.71.0
+grpcio-status==1.71.0
+h11==0.14.0; python_version >= '3.7'
+httpcore==1.0.7; python_version >= '3.8'
+httplib2==0.22.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+httpx==0.28.1; python_version >= '3.8'
+huggingface-hub==0.29.3; python_full_version >= '3.8.0'
+idna==3.10; python_version >= '3.6'
+jinja2==3.1.6; python_version >= '3.7'
+jiter==0.9.0; python_version >= '3.8'
+joblib==1.4.2; python_version >= '3.8'
+markdown-it-py==3.0.0; python_version >= '3.8'
+markupsafe==2.1.5; python_version >= '3.7'
+mdurl==0.1.2; python_version >= '3.7'
+mpmath==1.3.0
+multidict==6.1.0; python_version >= '3.8'
+multiprocess==0.70.16; python_version >= '3.8'
+networkx==3.4.2; python_version >= '3.10'
+numpy==2.2.3; python_version >= '3.10'
+openai==1.66.3; python_version >= '3.8'
+orjson==3.10.15; python_version >= '3.8'
+packaging==24.2; python_version >= '3.8'
+pandas==2.2.3; python_version >= '3.9'
+pillow==11.1.0; python_version >= '3.9'
+propcache==0.3.0; python_version >= '3.9'
+proto-plus==1.26.1; python_version >= '3.7'
+protobuf==5.29.3; python_version >= '3.8'
+pyarrow==19.0.1; python_version >= '3.9'
+pyasn1==0.6.1; python_version >= '3.8'
+pyasn1-modules==0.4.1; python_version >= '3.8'
+pydantic==2.10.6; python_version >= '3.8'
+pydantic-core==2.27.2; python_version >= '3.8'
+pydub==0.25.1
+pygments==2.19.1; python_version >= '3.8'
+pyparsing==3.2.1; python_version > '3.0'
+python-dateutil==2.9.0.post0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'
+python-dotenv==1.0.1; python_version >= '3.8'
+python-multipart==0.0.20; python_version >= '3.8'
+pytz==2025.1
+pyyaml==6.0.2; python_version >= '3.8'
+regex==2024.11.6; python_version >= '3.8'
+requests==2.32.3; python_version >= '3.8'
+rich==13.9.4; python_full_version >= '3.8.0'
+rsa==4.9; python_version >= '3.6' and python_version < '4'
+ruff==0.11.0; sys_platform != 'emscripten'
+safehttpx==0.1.6; python_version >= '3.10'
+safetensors==0.5.3; python_version >= '3.7'
+scikit-learn==1.6.1; python_version >= '3.9'
+scipy==1.15.2; python_version >= '3.10'
+semantic-version==2.10.0; python_version >= '2.7'
+sentence-transformers==3.4.1; python_version >= '3.9'
+shellingham==1.5.4; python_version >= '3.7'
+six==1.17.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'
+sniffio==1.3.1; python_version >= '3.7'
+starlette==0.46.1; sys_platform != 'emscripten'
+sympy==1.13.3; python_version >= '3.8'
+threadpoolctl==3.6.0; python_version >= '3.9'
+tokenizers==0.21.1; python_version >= '3.9'
+tomlkit==0.13.2; python_version >= '3.8'
+torch==2.2.2; python_full_version >= '3.8.0'
+tqdm==4.67.1; python_version >= '3.7'
+transformers==4.49.0; python_full_version >= '3.9.0'
+typer==0.15.2; sys_platform != 'emscripten'
+typing-extensions==4.12.2; python_version >= '3.8'
+tzdata==2025.1; python_version >= '2'
+uritemplate==4.1.1; python_version >= '3.6'
+urllib3==2.3.0; python_version >= '3.9'
+uvicorn==0.34.0; sys_platform != 'emscripten'
+websockets==15.0.1; python_version >= '3.9'
+xxhash==3.5.0; python_version >= '3.7'
+yarl==1.18.3; python_version >= '3.9'
src/genai.py
ADDED
@@ -0,0 +1,69 @@
+import json
+import os
+
+import google.generativeai as genai
+from dotenv import load_dotenv
+from pydantic import BaseModel
+
+load_dotenv()
+
+GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
+genai.configure(api_key=GOOGLE_API_KEY)
+
+
+class OutputSchema(BaseModel):
+    topic: str
+    answer: str
+
+
+def valid_output(output: dict) -> OutputSchema:
+    # Raises pydantic.ValidationError if the model output is missing keys.
+    return OutputSchema(**output)
+
+
+class GenAI:
+    def __init__(self):
+        self.genai_model = genai.GenerativeModel("gemini-2.0-flash-lite")
+
+    def generate_content(
+        self,
+        question: str,
+        question_context: str,
+        unique_topics: str,
+        few_shot_examples: str,
+    ) -> tuple:
+        prompt = f"""
+INSTRUCTIONS:
+1. You are an expert assistant to mental health counselors.
+2. You are given the following:
+2.1 Input Question: An input question from a mental health counselor seeking assistance.
+2.2 Input Question Context: Additional context for the input question. Can be None. If available, utilize this context in your response.
+2.3 Topics: Categories of topics to which any input question may belong. Every input question is categorized into one of these topics.
+2.4 Few-shot examples: Some example questions with their topics and answers.
+3. Output the following:
+3.1 A topic from the list of topics for the input question.
+3.2 A precise answer for the input question.
+- The length of the answer should not exceed 256 words.
+4. Your output MUST be in VALID JSON format.
+4.1 Follow the sample JSON format given below.
+
+INPUT:
+Input Question: {question}
+Input Question Context: {question_context}
+
+Topics: {unique_topics}
+
+Few-shot examples: {few_shot_examples}
+
+Sample output JSON format: [{{
+    "topic": "A topic from the list of topics for the input question.",
+    "answer": "An answer for the input question."
+}}]
+
+OUTPUT:"""
+        response = self.genai_model.generate_content(prompt)
+        # The model wraps its JSON array in a ```json ... ``` fence; strip it
+        # before parsing.
+        out = response.text.strip().removeprefix("```json").removesuffix("```")
+        out = json.loads(out)
+        valid_response = valid_output(out[0])
+
+        return (valid_response.answer, valid_response.topic)
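The post-processing in generate_content assumes the model wraps its JSON array in a Markdown fence. A minimal sketch of that step in isolation, with a hypothetical raw response:

import json

from src.genai import OutputSchema

raw = '```json\n[{"topic": "depression", "answer": "..."}]\n```'  # hypothetical model output
cleaned = raw.strip().removeprefix("```json").removesuffix("```").strip()
parsed = json.loads(cleaned)
validated = OutputSchema(**parsed[0])  # raises pydantic.ValidationError on bad keys
print(validated.topic)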
src/semantic_searcher.py
ADDED
@@ -0,0 +1,41 @@
+import pandas as pd
+import torch
+from sentence_transformers import SentenceTransformer
+
+
+class SemanticSearcher:
+    def __init__(self, df_counsel_chat, df_counsel_chat_topic):
+        # Full dataset (all answers, with upvotes/views/therapist metadata) and
+        # the deduplicated per-question view that gets embedded.
+        self.df_counsel_chat = df_counsel_chat
+        self.df_counsel_chat_topic = df_counsel_chat_topic
+        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
+        self.question_embeddings = self.embedder.encode(
+            self.df_counsel_chat_topic["questionCombined"].tolist(),
+            show_progress_bar=True,
+            convert_to_tensor=True,
+        )
+
+    def retrieve_relevant_qna(
+        self, question: str, question_context: str = None
+    ) -> pd.DataFrame:
+        if question_context is None:
+            question_context = ""
+        query = question + "\n" + question_context
+        query_embedding = self.embedder.encode(query, convert_to_tensor=True)
+
+        # Cosine similarity against all question embeddings; torch.topk keeps
+        # the single best match.
+        similarity_scores = self.embedder.similarity(
+            query_embedding, self.question_embeddings
+        )[0]
+        _, indices = torch.topk(similarity_scores, k=1)
+        index = indices.tolist()
+        question_id = self.df_counsel_chat_topic.loc[index, "questionID"].values[0]
+        # Return up to three answers to that question, most upvoted/viewed first.
+        relevant_qna = (
+            self.df_counsel_chat.loc[self.df_counsel_chat["questionID"] == question_id]
+            .sort_values(by=["upvotes", "views"], ascending=False)
+            .head(3)[[
+                "questionTitle",
+                "topic",
+                "therapistInfo",
+                "therapistURL",
+                "answerText",
+                "upvotes",
+                "views",
+            ]]
+        )
+        return relevant_qna
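A minimal usage sketch for the searcher; assumes df_counsel_chat and df_counsel_chat_topic are built as in app.py:

from src.semantic_searcher import SemanticSearcher

searcher = SemanticSearcher(df_counsel_chat, df_counsel_chat_topic)
top_answers = searcher.retrieve_relevant_qna(
    "How can I change my feeling of being worthless to everyone?"
)
print(top_answers[["questionTitle", "upvotes", "views"]])  # at most three rows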
src/upvote_predictor.py
ADDED
@@ -0,0 +1,71 @@
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
+from transformers import BertTokenizer
+
+from src.semantic_searcher import SemanticSearcher
+
+
+class UpvotePredictor:
+    def __init__(self, model_path: str, semantic_searcher: SemanticSearcher):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.upvote_ml_model = torch.load(
+            model_path, map_location=torch.device("cpu"), weights_only=False
+        )
+        self.tokenizer = BertTokenizer.from_pretrained(
+            "bert-base-uncased", do_lower_case=True
+        )
+        # Fallback retrieval for answers judged not credible.
+        self.semantic_searcher = semantic_searcher
+        self.upvote_ml_model.to(self.device)
+        self.upvote_ml_model.eval()
+
+    def get_upvote_prediction(
+        self, question: str, answer: str, question_context: str = None
+    ) -> tuple:
+        llm_response_input_ids = []
+        llm_response_attention_masks = []
+
+        encoded_dict = self.tokenizer.encode_plus(
+            answer,
+            add_special_tokens=True,
+            max_length=256,
+            padding="max_length",
+            truncation=True,
+            return_attention_mask=True,
+            return_tensors="pt",
+        )
+        llm_response_input_ids.append(encoded_dict["input_ids"])
+        llm_response_attention_masks.append(encoded_dict["attention_mask"])
+        llm_response_input_ids = torch.cat(llm_response_input_ids, dim=0)
+        llm_response_attention_masks = torch.cat(llm_response_attention_masks, dim=0)
+
+        test_dataset = TensorDataset(
+            llm_response_input_ids, llm_response_attention_masks
+        )
+        test_dataloader = DataLoader(
+            test_dataset,  # The samples to score.
+            sampler=SequentialSampler(test_dataset),  # Pull out batches sequentially.
+            batch_size=1,  # Evaluate with this batch size.
+        )
+
+        predictions = []
+        for batch in test_dataloader:
+            b_input_ids = batch[0].to(self.device)
+            b_input_mask = batch[1].to(self.device)
+            with torch.no_grad():
+                output = self.upvote_ml_model(
+                    b_input_ids, token_type_ids=None, attention_mask=b_input_mask
+                )
+            logits = output.logits
+            logits = logits.detach().cpu().numpy()
+            pred_flat = np.argmax(logits, axis=1).flatten()
+
+            predictions.extend(list(pred_flat))
+
+        if predictions[0] == 0:
+            return (
+                "Not credible suggestion",
+                self.semantic_searcher.retrieve_relevant_qna(question, question_context),
+            )
+        else:
+            return ("Credible suggestion", pd.DataFrame())
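A minimal usage sketch for the predictor; assumes a fine-tuned BERT sequence classifier was saved whole with torch.save(model, "models/bert_model") (hence the weights_only=False load above), and reuses the searcher from the previous sketch:

predictor = UpvotePredictor("models/bert_model", searcher)
label, fallback_df = predictor.get_upvote_prediction(
    question="How can I change my feeling of being worthless to everyone?",
    answer=answer,  # the GenAI suggestion produced earlier
)
if label == "Not credible suggestion":
    print(fallback_df)  # similar, highly upvoted Q&A to show instead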