# Hosting-page residue captured with this file: "Spaces: Sleeping" (not part of the program).
from typing import Optional

import torch
from sentence_transformers import SentenceTransformer
class SemanticSearcher:
    """Retrieve stored answers for the question most similar to a query.

    Every entry of ``df_counsel_chat_topic["questionCombined"]`` is embedded
    once at construction time. A query is then embedded with the same model,
    compared against those embeddings, and the answers belonging to the
    best-matching question are returned from ``df_counsel_chat``.
    """

    # Columns of ``df_counsel_chat`` exposed in the result, in output order.
    _ANSWER_COLUMNS = [
        "questionTitle",
        "topic",
        "therapistInfo",
        "therapistURL",
        "answerText",
        "upvotes",
        "views",
    ]

    def __init__(self, df_counsel_chat_topic, df_counsel_chat):
        """Load the embedding model and pre-compute the question embeddings.

        Args:
            df_counsel_chat_topic: DataFrame with (at least) the columns
                ``questionCombined`` (text that gets embedded) and
                ``questionID``.
            df_counsel_chat: DataFrame with a ``questionID`` column plus the
                answer columns returned by :meth:`retrieve_relevant_qna`.
        """
        self.df_counsel_chat_topic = df_counsel_chat_topic
        self.df_counsel_chat = df_counsel_chat
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        # Row i of ``question_embeddings`` corresponds to the i-th *positional*
        # row of ``df_counsel_chat_topic`` (the texts are passed as a list).
        self.question_embeddings = self.embedder.encode(
            self.df_counsel_chat_topic["questionCombined"].tolist(),
            show_progress_bar=True,
            convert_to_tensor=True,
        )

    def retrieve_relevant_qna(
        self,
        question: str,
        question_context: Optional[str] = None,
        *,
        max_answers: int = 3,
    ):
        """Return the top answers of the stored question closest to the query.

        Args:
            question: The user's question.
            question_context: Optional extra text appended to the question
                (after a newline) before embedding. ``None`` is treated as
                an empty string.
            max_answers: Maximum number of answers to return (default 3).

        Returns:
            A DataFrame slice of ``df_counsel_chat`` holding up to
            ``max_answers`` rows for the best-matching ``questionID``,
            ordered by ``upvotes`` then ``views`` (both descending) and
            restricted to the answer columns.
        """
        if question_context is None:
            question_context = ""
        query = question + "\n" + question_context
        query_embedding = self.embedder.encode(query, convert_to_tensor=True)

        # ``similarity`` returns one row per query; there is a single query,
        # so row 0 holds its score against every stored question.
        # NOTE(review): the metric is whatever the model is configured with
        # (cosine by default in sentence-transformers) -- confirm if changed.
        similarity_scores = self.embedder.similarity(
            query_embedding, self.question_embeddings
        )[0]

        # Only the single best match is needed.
        best_position = int(torch.argmax(similarity_scores))

        # The position indexes the embedding matrix, so it must be resolved
        # positionally (``iloc``); ``loc`` would treat it as an index label
        # and fail or pick the wrong row when the index is not 0..n-1.
        question_id = self.df_counsel_chat_topic["questionID"].iloc[best_position]

        answers = self.df_counsel_chat.loc[
            self.df_counsel_chat["questionID"] == question_id
        ]
        ranked = answers.sort_values(by=["upvotes", "views"], ascending=False)
        return ranked.head(max_answers)[self._ANSWER_COLUMNS]