gupta-amulya committed on
Commit
20df6e4
·
1 Parent(s): 39324d3

Enhance SemanticSearcher integration and refine UpvotePredictor output handling

Browse files
Files changed (3) hide show
  1. app.py +7 -3
  2. src/semantic_searcher.py +3 -1
  3. src/upvote_predictor.py +2 -6
app.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  from datasets import load_dataset
6
 
7
  from src.genai import GenAI
 
8
  from src.upvote_predictor import UpvotePredictor
9
 
10
  # Load the dataset
@@ -52,18 +53,21 @@ examples = "\n".join(
52
  # Initialize the SemanticSearcher
53
  genai = GenAI()
54
  upvote_predictor = UpvotePredictor("src/bert_model")
55
- # _ = SemanticSearcher(df_counsel_chat_topic)
56
 
57
 
58
  def get_output(question: str, question_context: str = None) -> str:
59
  answer, topic = genai.generate_content(
60
  question, question_context, unique_topics, examples
61
  )
62
- # return (answer, topic, "Yes", pd.DataFrame())
63
  upvote_prediction = upvote_predictor.get_upvote_prediction(
64
  question, answer, question_context
65
  )
66
- return (answer, topic, upvote_prediction[0], upvote_prediction[1])
 
 
 
 
67
 
68
 
69
  demo = gr.Interface(
 
5
  from datasets import load_dataset
6
 
7
  from src.genai import GenAI
8
+ from src.semantic_searcher import SemanticSearcher
9
  from src.upvote_predictor import UpvotePredictor
10
 
11
  # Load the dataset
 
53
  # Initialize the SemanticSearcher
54
  genai = GenAI()
55
  upvote_predictor = UpvotePredictor("src/bert_model")
56
+ ss = SemanticSearcher(df_counsel_chat_topic, df_counsel_chat)
57
 
58
 
59
  def get_output(question: str, question_context: str = None) -> str:
60
  answer, topic = genai.generate_content(
61
  question, question_context, unique_topics, examples
62
  )
 
63
  upvote_prediction = upvote_predictor.get_upvote_prediction(
64
  question, answer, question_context
65
  )
66
+ if "not" in upvote_prediction.lower():
67
+ df = ss.retrieve_relevant_qna(question, question_context)
68
+ return (answer, topic, upvote_prediction, df)
69
+ else:
70
+ return (answer, topic, upvote_prediction, pd.DataFrame())
71
 
72
 
73
  demo = gr.Interface(
src/semantic_searcher.py CHANGED
@@ -3,14 +3,16 @@ from sentence_transformers import SentenceTransformer
3
 
4
 
5
  class SemanticSearcher:
6
- def __init__(self, df_counsel_chat_topic):
7
  self.df_counsel_chat_topic = df_counsel_chat_topic
 
8
  self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
9
  self.question_embeddings = self.embedder.encode(
10
  self.df_counsel_chat_topic["questionCombined"].tolist(),
11
  show_progress_bar=True,
12
  convert_to_tensor=True,
13
  )
 
14
  def retrieve_relevant_qna(self, question: str, question_context: str = None):
15
  if question_context is None:
16
  question_context = ""
 
3
 
4
 
5
  class SemanticSearcher:
6
+ def __init__(self, df_counsel_chat_topic, df_counsel_chat):
7
  self.df_counsel_chat_topic = df_counsel_chat_topic
8
+ self.df_counsel_chat = df_counsel_chat
9
  self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
10
  self.question_embeddings = self.embedder.encode(
11
  self.df_counsel_chat_topic["questionCombined"].tolist(),
12
  show_progress_bar=True,
13
  convert_to_tensor=True,
14
  )
15
+
16
  def retrieve_relevant_qna(self, question: str, question_context: str = None):
17
  if question_context is None:
18
  question_context = ""
src/upvote_predictor.py CHANGED
@@ -1,5 +1,4 @@
1
  import numpy as np
2
- import pandas as pd
3
  import torch
4
  from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
5
  from transformers import BertTokenizer
@@ -61,9 +60,6 @@ class UpvotePredictor:
61
  predictions.extend(list(pred_flat))
62
 
63
  if predictions[0] == 0:
64
- return (
65
- "Not credible suggestion",
66
- pd.DataFrame(),
67
- )
68
  else:
69
- return ("Credible suggestion", pd.DataFrame())
 
1
  import numpy as np
 
2
  import torch
3
  from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
4
  from transformers import BertTokenizer
 
60
  predictions.extend(list(pred_flat))
61
 
62
  if predictions[0] == 0:
63
+ return "Not credible suggestion"
 
 
 
64
  else:
65
+ return "Credible suggestion"