# legacydemo / src/semantic_searcher.py
# Last commit 20df6e4 (gupta-amulya):
#   "Enhance SemanticSearcher integration and refine UpvotePredictor output handling"
import torch
from sentence_transformers import SentenceTransformer
class SemanticSearcher:
    """Find counsellor answers for the stored question most similar to a query.

    At construction, every entry of ``df_counsel_chat_topic["questionCombined"]``
    is embedded once. Each query is then embedded and scored against those
    stored embeddings with the embedder's ``similarity`` function (cosine for
    SentenceTransformer models unless their config says otherwise — confirm if
    the model changes). The answers for the best match are read from
    ``df_counsel_chat``.
    """

    # Columns returned by retrieve_relevant_qna, in output order.
    _ANSWER_COLUMNS = (
        "questionTitle",
        "topic",
        "therapistInfo",
        "therapistURL",
        "answerText",
        "upvotes",
        "views",
    )

    def __init__(self, df_counsel_chat_topic, df_counsel_chat, model_name: str = "all-MiniLM-L6-v2"):
        """Load the embedding model and pre-compute embeddings of the stored questions.

        Args:
            df_counsel_chat_topic: DataFrame with one row per stored question.
                It must have ``questionCombined`` (the text to embed) and
                ``questionID`` columns.
            df_counsel_chat: DataFrame of answers. It must have ``questionID``,
                ``upvotes`` and ``views`` columns, plus the columns listed in
                ``_ANSWER_COLUMNS``.
            model_name: SentenceTransformer model to load.
        """
        self.df_counsel_chat_topic = df_counsel_chat_topic
        self.df_counsel_chat = df_counsel_chat
        self.embedder = SentenceTransformer(model_name)
        # Row i of this tensor corresponds to *position* i of df_counsel_chat_topic.
        self.question_embeddings = self.embedder.encode(
            self.df_counsel_chat_topic["questionCombined"].tolist(),
            show_progress_bar=True,
            convert_to_tensor=True,
        )

    def retrieve_relevant_qna(
        self,
        question: str,
        question_context: "str | None" = None,
        max_answers: int = 3,
    ):
        """Return the top answers for the stored question closest to the query.

        Args:
            question: The user's question text.
            question_context: Optional extra text, appended after a newline to
                form the query. ``None`` is treated as an empty string.
            max_answers: Maximum number of answer rows to return.

        Returns:
            A DataFrame with the ``_ANSWER_COLUMNS`` columns, holding up to
            ``max_answers`` answers of the best-matching question. Rows are
            sorted by ``upvotes`` then ``views``, both descending. The frame is
            empty when there are no stored questions or the best match has no
            answers.
        """
        if question_context is None:
            question_context = ""
        answer_columns = list(self._ANSWER_COLUMNS)

        # Nothing to match against: return an empty frame with the expected columns.
        if len(self.df_counsel_chat_topic) == 0:
            return self.df_counsel_chat.iloc[0:0][answer_columns]

        query = question + "\n" + question_context
        query_embedding = self.embedder.encode(query, convert_to_tensor=True)
        # A single query yields a (1, n_questions) score matrix; take its only row.
        similarity_scores = self.embedder.similarity(
            query_embedding, self.question_embeddings
        )[0]
        # Only the single best-matching stored question is used.
        best_position = int(torch.argmax(similarity_scores))
        # Embedding rows follow the DataFrame's row *positions*, not its index
        # labels, so positional .iloc is required (label-based .loc breaks on
        # any non-default index).
        question_id = self.df_counsel_chat_topic["questionID"].iloc[best_position]

        matches = self.df_counsel_chat.loc[
            self.df_counsel_chat["questionID"] == question_id
        ]
        return (
            matches.sort_values(by=["upvotes", "views"], ascending=False)
            .head(max_answers)[answer_columns]
        )