gupta-amulya committed on
Commit 0782370 · 1 Parent(s): 105856f

Add initial implementation of GenAI and related components for mental health assistance

Files changed (6)
  1. .env +1 -0
  2. app.py +93 -0
  3. requirements.txt +101 -0
  4. src/genai.py +69 -0
  5. src/semantic_searcher.py +41 -0
  6. src/upvote_predictor.py +71 -0
.env ADDED
@@ -0,0 +1 @@
+ GOOGLE_API_KEY = 'AIzaSyCuB81K0_G9wYnUkX9QF20bFSD7fEnTJ6k'
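For reference, src/genai.py (later in this diff) reads this key at import time with python-dotenv; a minimal sketch of that pattern:

    import os

    from dotenv import load_dotenv  # python-dotenv, pinned in requirements.txt

    load_dotenv()  # reads key=value pairs from .env into the environment
    # Note: .env files holding real credentials are normally git-ignored, not committed.
    api_key = os.environ.get("GOOGLE_API_KEY")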
app.py ADDED
@@ -0,0 +1,93 @@
+ import copy
+
+ import gradio as gr
+ import pandas as pd
+ from datasets import load_dataset
+
+ from src.genai import GenAI
+
+ # from src.semantic_searcher import SemanticSearcher
+ # from src.upvote_predictor import UpvotePredictor
+
+ # Load the dataset
+ dataset_counsel_chat = load_dataset("nbertagnolli/counsel-chat")
+ df_counsel_chat = pd.DataFrame(dataset_counsel_chat["train"])
+ df_counsel_chat_topic = copy.deepcopy(
+     df_counsel_chat[
+         ["questionID", "questionTitle", "questionText", "answerText", "topic"]
+     ]
+ )
+ df_counsel_chat_topic["questionCombined"] = df_counsel_chat_topic.apply(
+     lambda x: (
+         f"QUESTION_TITLE: {x['questionTitle']}\nQUESTION_CONTEXT: {x['questionText']}"
+     ),
+     axis=1,
+ )
+ df_counsel_chat_topic = df_counsel_chat_topic.drop_duplicates(
+     subset="questionID"
+ ).reset_index(drop=True)
+ # List of unique topics, numbered for the prompt
+ unique_topics = sorted(df_counsel_chat_topic["topic"].unique().tolist())
+ unique_topics = "\n".join(
+     [f"{idx + 1}. {topic}" for idx, topic in enumerate(unique_topics)]
+ )
+
+ # Few-shot examples: one randomly sampled question-answer pair per topic
+ few_examples = (
+     df_counsel_chat_topic.groupby("topic", as_index=False)[
+         ["questionID", "questionCombined", "answerText", "topic"]
+     ]
+     .apply(lambda s: s.sample(1))
+     .reset_index(drop=True)
+ )
+ few_examples["examples"] = few_examples.apply(
+     lambda x: (
+         f"{x['questionCombined']}\nTOPIC: {x['topic']}\nANSWER: {x['answerText']}"
+     ),
+     axis=1,
+ )
+ examples = "\n".join(
+     f"<EXAMPLE {idx + 1} start>\n{example}\n<EXAMPLE {idx + 1} end>\n\n"
+     for idx, example in enumerate(few_examples["examples"].to_list())
+ )
+
+ # Initialize the GenAI client
+ genai = GenAI()
+ # upvote_predictor = UpvotePredictor("models/bert_model")
+ # _ = SemanticSearcher(df_counsel_chat_topic)
+
+
+ def get_output(question: str, question_context: str = None) -> tuple:
+     answer, topic = genai.generate_content(
+         question, question_context, unique_topics, examples
+     )
+     return (answer, topic, "Yes", pd.DataFrame())
+     # upvote_prediction = upvote_predictor.get_upvote_prediction(
+     #     question, answer, question_context
+     # )
+     # return (answer, topic, upvote_prediction[0], upvote_prediction[1])
+
+
+ demo = gr.Interface(
+     fn=get_output,
+     inputs=[
+         gr.Textbox(label="Input Question"),
+         gr.Textbox(label="(Optional) Additional Context for Question"),
+     ],
+     outputs=[
+         gr.Textbox(label="GenAI based suggestion"),
+         gr.Textbox(label="Suggested Topic of Question"),
+         gr.Textbox(label="Is GenAI based suggestion credible?"),
+         gr.Dataframe(
+             label=(
+                 "Semantically similar questions (and other metadata) to input question."
+                 " Will be available if GenAI based suggestion is not credible."
+             )
+         ),
+     ],
+ )
+
+ demo.launch(debug=True)
+ # Example input for manual testing:
+ # input_question_context = "I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here. I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it. How can I change my feeling of being worthless to everyone?"
+ # input_question = "How can I change my feeling of being worthless to everyone?"
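Running `python app.py` serves the Gradio UI locally (Gradio defaults to http://127.0.0.1:7860). For a quick smoke test without the UI, `get_output` can be called directly; a sketch using the example question already in the file:

    # Sketch: direct call to get_output, bypassing the Gradio interface
    answer, topic, credible, similar = get_output(
        "How can I change my feeling of being worthless to everyone?"
    )
    print(topic, credible)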
requirements.txt ADDED
@@ -0,0 +1,101 @@
+ -i https://pypi.org/simple
+ aiofiles==23.2.1; python_version >= '3.7'
+ aiohappyeyeballs==2.6.1; python_version >= '3.9'
+ aiohttp==3.11.13; python_version >= '3.9'
+ aiosignal==1.3.2; python_version >= '3.9'
+ annotated-types==0.7.0; python_version >= '3.8'
+ anyio==4.8.0; python_version >= '3.9'
+ attrs==25.3.0; python_version >= '3.8'
+ cachetools==5.5.2; python_version >= '3.7'
+ certifi==2025.1.31; python_version >= '3.6'
+ charset-normalizer==3.4.1; python_version >= '3.7'
+ click==8.1.8; python_version >= '3.7'
+ datasets==3.4.0; python_full_version >= '3.9.0'
+ dill==0.3.8; python_version >= '3.8'
+ distro==1.9.0; python_version >= '3.6'
+ fastapi==0.115.11; python_version >= '3.8'
+ ffmpy==0.5.0; python_version >= '3.8' and python_version < '4.0'
+ filelock==3.18.0; python_version >= '3.9'
+ frozenlist==1.5.0; python_version >= '3.8'
+ fsspec[http]==2024.12.0; python_version >= '3.8'
+ google-ai-generativelanguage==0.6.15; python_version >= '3.7'
+ google-api-core[grpc]==2.24.2; python_version >= '3.7'
+ google-api-python-client==2.164.0; python_version >= '3.7'
+ google-auth==2.38.0; python_version >= '3.7'
+ google-auth-httplib2==0.2.0
+ google-generativeai==0.8.4; python_version >= '3.9'
+ googleapis-common-protos==1.69.1; python_version >= '3.7'
+ gradio==5.21.0; python_version >= '3.10'
+ gradio-client==1.7.2; python_version >= '3.10'
+ groovy==0.1.2; python_version >= '3.10'
+ grpcio==1.71.0
+ grpcio-status==1.71.0
+ h11==0.14.0; python_version >= '3.7'
+ httpcore==1.0.7; python_version >= '3.8'
+ httplib2==0.22.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+ httpx==0.28.1; python_version >= '3.8'
+ huggingface-hub==0.29.3; python_full_version >= '3.8.0'
+ idna==3.10; python_version >= '3.6'
+ jinja2==3.1.6; python_version >= '3.7'
+ jiter==0.9.0; python_version >= '3.8'
+ joblib==1.4.2; python_version >= '3.8'
+ markdown-it-py==3.0.0; python_version >= '3.8'
+ markupsafe==2.1.5; python_version >= '3.7'
+ mdurl==0.1.2; python_version >= '3.7'
+ mpmath==1.3.0
+ multidict==6.1.0; python_version >= '3.8'
+ multiprocess==0.70.16; python_version >= '3.8'
+ networkx==3.4.2; python_version >= '3.10'
+ numpy==2.2.3; python_version >= '3.10'
+ openai==1.66.3; python_version >= '3.8'
+ orjson==3.10.15; python_version >= '3.8'
+ packaging==24.2; python_version >= '3.8'
+ pandas==2.2.3; python_version >= '3.9'
+ pillow==11.1.0; python_version >= '3.9'
+ propcache==0.3.0; python_version >= '3.9'
+ proto-plus==1.26.1; python_version >= '3.7'
+ protobuf==5.29.3; python_version >= '3.8'
+ pyarrow==19.0.1; python_version >= '3.9'
+ pyasn1==0.6.1; python_version >= '3.8'
+ pyasn1-modules==0.4.1; python_version >= '3.8'
+ pydantic==2.10.6; python_version >= '3.8'
+ pydantic-core==2.27.2; python_version >= '3.8'
+ pydub==0.25.1
+ pygments==2.19.1; python_version >= '3.8'
+ pyparsing==3.2.1; python_version > '3.0'
+ python-dateutil==2.9.0.post0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'
+ python-dotenv==1.0.1; python_version >= '3.8'
+ python-multipart==0.0.20; python_version >= '3.8'
+ pytz==2025.1
+ pyyaml==6.0.2; python_version >= '3.8'
+ regex==2024.11.6; python_version >= '3.8'
+ requests==2.32.3; python_version >= '3.8'
+ rich==13.9.4; python_full_version >= '3.8.0'
+ rsa==4.9; python_version >= '3.6' and python_version < '4'
+ ruff==0.11.0; sys_platform != 'emscripten'
+ safehttpx==0.1.6; python_version >= '3.10'
+ safetensors==0.5.3; python_version >= '3.7'
+ scikit-learn==1.6.1; python_version >= '3.9'
+ scipy==1.15.2; python_version >= '3.10'
+ semantic-version==2.10.0; python_version >= '2.7'
+ sentence-transformers==3.4.1; python_version >= '3.9'
+ shellingham==1.5.4; python_version >= '3.7'
+ six==1.17.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'
+ sniffio==1.3.1; python_version >= '3.7'
+ starlette==0.46.1; sys_platform != 'emscripten'
+ sympy==1.13.3; python_version >= '3.8'
+ threadpoolctl==3.6.0; python_version >= '3.9'
+ tokenizers==0.21.1; python_version >= '3.9'
+ tomlkit==0.13.2; python_version >= '3.8'
+ torch==2.2.2; python_full_version >= '3.8.0'
+ tqdm==4.67.1; python_version >= '3.7'
+ transformers==4.49.0; python_full_version >= '3.9.0'
+ typer==0.15.2; sys_platform != 'emscripten'
+ typing-extensions==4.12.2; python_version >= '3.8'
+ tzdata==2025.1; python_version >= '2'
+ uritemplate==4.1.1; python_version >= '3.6'
+ urllib3==2.3.0; python_version >= '3.9'
+ uvicorn==0.34.0; sys_platform != 'emscripten'
+ websockets==15.0.1; python_version >= '3.9'
+ xxhash==3.5.0; python_version >= '3.7'
+ yarl==1.18.3; python_version >= '3.9'
src/genai.py ADDED
@@ -0,0 +1,69 @@
+ import json
+ import os
+
+ import google.generativeai as genai
+ from dotenv import load_dotenv
+ from pydantic import BaseModel
+
+ load_dotenv()
+
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
+ genai.configure(api_key=GOOGLE_API_KEY)
+
+
+ class OutputSchema(BaseModel):
+     topic: str
+     answer: str
+
+
+ def valid_output(output: dict) -> OutputSchema:
+     return OutputSchema(**output)
+
+
+ class GenAI:
+     def __init__(self):
+         self.genai_model = genai.GenerativeModel("gemini-2.0-flash-lite")
+
+     def generate_content(
+         self,
+         question: str,
+         question_context: str,
+         unique_topics: str,
+         few_shot_examples: str,
+     ):
+         prompt = f"""
+ INSTRUCTIONS:
+ 1. You are an expert assistant to mental health counselors.
+ 2. You are given the following:
+ 2.1 Input Question: An input question from a mental health counselor seeking assistance.
+ 2.2 Input Question Context: Additional context for the input question. Can be None. If available, utilize this context in your response.
+ 2.3 Topics: Categories of topics to which any input question may belong. Any input question will be categorized into one of these topics.
+ 2.4 Few-shot examples: Some examples of questions with their topic and answer.
+ 3. Output the following:
+ 3.1 A topic from the list of topics for the input question.
+ 3.2 A precise answer for the input question.
+ - The length of the answer should not exceed 256 words.
+ 4. Your output MUST be VALID JSON.
+ 4.1 Follow the sample JSON format given below.
+
+ INPUT:
+ Input Question: {question}
+ Input Question Context: {question_context}
+
+ Topics: {unique_topics}
+
+ Few-shot examples: {few_shot_examples}
+
+ Sample output JSON format: [{{
+     "topic": "A topic from the list of topics for the input question.",
+     "answer": "An answer for the input question."
+ }}]
+
+ OUTPUT:"""
+         response = self.genai_model.generate_content(prompt)
+         out = response.text
+         # Strip the ```json ... ``` fence the model wraps around its reply.
+         out = out[7:-3]
+         out = json.loads(out)
+         valid_response = valid_output(out[0])
+
+         return (valid_response.answer, valid_response.topic)
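The `out[7:-3]` slice assumes the model always wraps its JSON in a ```json fence; if it ever returns bare JSON, the slice silently corrupts it. A more defensive parse (a sketch, not part of this commit; parse_model_json is a hypothetical helper) strips the fence only when present:

    import json
    import re

    def parse_model_json(text: str):
        # Hypothetical helper: strip an optional ```json ... ``` fence, then parse.
        match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
        payload = match.group(1) if match else text.strip()
        return json.loads(payload)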
src/semantic_searcher.py ADDED
@@ -0,0 +1,41 @@
+ import pandas as pd
+ import torch
+ from sentence_transformers import SentenceTransformer
+
+
+ class SemanticSearcher:
+     def __init__(self, df_counsel_chat_topic):
+         self.df_counsel_chat_topic = df_counsel_chat_topic
+         self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
+         self.question_embeddings = self.embedder.encode(
+             self.df_counsel_chat_topic["questionCombined"].tolist(),
+             show_progress_bar=True,
+             convert_to_tensor=True,
+         )
+
+     def retrieve_relevant_qna(
+         self, question: str, question_context: str = None
+     ) -> pd.DataFrame:
+         if question_context is None:
+             question_context = ""
+         query = question + "\n" + question_context
+         query_embedding = self.embedder.encode(query, convert_to_tensor=True)
+
+         # Use cosine similarity and torch.topk to find the single closest question.
+         similarity_scores = self.embedder.similarity(
+             query_embedding, self.question_embeddings
+         )[0]
+         _, indices = torch.topk(similarity_scores, k=1)
+         index = indices.tolist()
+         question_id = self.df_counsel_chat_topic.loc[index, "questionID"].values[0]
+         # Note: the dataframe passed to __init__ must also carry the upvotes,
+         # views, therapistInfo, and therapistURL columns selected below.
+         relevant_qna = (
+             self.df_counsel_chat_topic.loc[
+                 self.df_counsel_chat_topic["questionID"] == question_id
+             ]
+             .sort_values(by=["upvotes", "views"], ascending=False)
+             .head(3)[
+                 [
+                     "questionTitle",
+                     "topic",
+                     "therapistInfo",
+                     "therapistURL",
+                     "answerText",
+                     "upvotes",
+                     "views",
+                 ]
+             ]
+         )
+         return relevant_qna
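A minimal usage sketch (assuming a dataframe with the columns noted above; the query string is hypothetical):

    from src.semantic_searcher import SemanticSearcher

    searcher = SemanticSearcher(df_counsel_chat_topic)  # built as in app.py
    similar = searcher.retrieve_relevant_qna("How can I manage panic attacks?")
    print(similar[["questionTitle", "topic", "upvotes"]])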
src/upvote_predictor.py ADDED
@@ -0,0 +1,71 @@
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
+ from transformers import BertTokenizer
+
+ from src.semantic_searcher import SemanticSearcher
+
+
+ class UpvotePredictor:
+     def __init__(self, model_path: str, semantic_searcher: SemanticSearcher = None):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.upvote_ml_model = torch.load(
+             model_path, map_location=torch.device("cpu"), weights_only=False
+         )
+         self.tokenizer = BertTokenizer.from_pretrained(
+             "bert-base-uncased", do_lower_case=True
+         )
+         # retrieve_relevant_qna is a SemanticSearcher method, so an instance is
+         # needed to fetch fallback answers for non-credible suggestions.
+         self.semantic_searcher = semantic_searcher
+         self.upvote_ml_model.to(self.device)
+         self.upvote_ml_model.eval()
+
+     def get_upvote_prediction(
+         self, question: str, answer: str, question_context: str = None
+     ) -> tuple:
+         llm_response_input_ids = []
+         llm_response_attention_masks = []
+
+         encoded_dict = self.tokenizer.encode_plus(
+             answer,
+             add_special_tokens=True,
+             max_length=256,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors="pt",
+         )
+         llm_response_input_ids.append(encoded_dict["input_ids"])
+         llm_response_attention_masks.append(encoded_dict["attention_mask"])
+         llm_response_input_ids = torch.cat(llm_response_input_ids, dim=0)
+         llm_response_attention_masks = torch.cat(llm_response_attention_masks, dim=0)
+
+         test_dataset = TensorDataset(
+             llm_response_input_ids, llm_response_attention_masks
+         )
+         test_dataloader = DataLoader(
+             test_dataset,  # The validation samples.
+             sampler=SequentialSampler(test_dataset),  # Pull out batches sequentially.
+             batch_size=1,  # Evaluate with this batch size.
+         )
+
+         predictions = []
+         for batch in test_dataloader:
+             b_input_ids = batch[0].to(self.device)
+             b_input_mask = batch[1].to(self.device)
+             with torch.no_grad():
+                 output = self.upvote_ml_model(
+                     b_input_ids, token_type_ids=None, attention_mask=b_input_mask
+                 )
+             logits = output.logits
+             logits = logits.detach().cpu().numpy()
+             pred_flat = np.argmax(logits, axis=1).flatten()
+
+             predictions.extend(list(pred_flat))
+
+         if predictions[0] == 0:
+             return (
+                 "Not credible suggestion",
+                 self.semantic_searcher.retrieve_relevant_qna(
+                     question, question_context
+                 ),
+             )
+         else:
+             return ("Credible suggestion", pd.DataFrame())