Spaces:
Sleeping
Sleeping
Commit
·
20df6e4
1
Parent(s):
39324d3
Enhance SemanticSearcher integration and refine UpvotePredictor output handling
Browse files
- app.py +7 -3
- src/semantic_searcher.py +3 -1
- src/upvote_predictor.py +2 -6
app.py
CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
|
|
5 |
from datasets import load_dataset
|
6 |
|
7 |
from src.genai import GenAI
|
|
|
8 |
from src.upvote_predictor import UpvotePredictor
|
9 |
|
10 |
# Load the dataset
|
@@ -52,18 +53,21 @@ examples = "\n".join(
|
|
52 |
# Initialize the SemanticSearcher
|
53 |
genai = GenAI()
|
54 |
upvote_predictor = UpvotePredictor("src/bert_model")
|
55 |
-
|
56 |
|
57 |
|
58 |
def get_output(question: str, question_context: str = None) -> str:
|
59 |
answer, topic = genai.generate_content(
|
60 |
question, question_context, unique_topics, examples
|
61 |
)
|
62 |
-
# return (answer, topic, "Yes", pd.DataFrame())
|
63 |
upvote_prediction = upvote_predictor.get_upvote_prediction(
|
64 |
question, answer, question_context
|
65 |
)
|
66 |
-
|
|
|
|
|
|
|
|
|
67 |
|
68 |
|
69 |
demo = gr.Interface(
|
|
|
5 |
from datasets import load_dataset
|
6 |
|
7 |
from src.genai import GenAI
|
8 |
+
from src.semantic_searcher import SemanticSearcher
|
9 |
from src.upvote_predictor import UpvotePredictor
|
10 |
|
11 |
# Load the dataset
|
|
|
53 |
# Initialize the SemanticSearcher
|
54 |
genai = GenAI()
|
55 |
upvote_predictor = UpvotePredictor("src/bert_model")
|
56 |
+
ss = SemanticSearcher(df_counsel_chat_topic, df_counsel_chat)
|
57 |
|
58 |
|
59 |
def get_output(question: str, question_context: str = None) -> str:
|
60 |
answer, topic = genai.generate_content(
|
61 |
question, question_context, unique_topics, examples
|
62 |
)
|
|
|
63 |
upvote_prediction = upvote_predictor.get_upvote_prediction(
|
64 |
question, answer, question_context
|
65 |
)
|
66 |
+
if "not" in upvote_prediction.lower():
|
67 |
+
df = ss.retrieve_relevant_qna(question, question_context)
|
68 |
+
return (answer, topic, upvote_prediction, df)
|
69 |
+
else:
|
70 |
+
return (answer, topic, upvote_prediction, pd.DataFrame())
|
71 |
|
72 |
|
73 |
demo = gr.Interface(
|
src/semantic_searcher.py
CHANGED
@@ -3,14 +3,16 @@ from sentence_transformers import SentenceTransformer
|
|
3 |
|
4 |
|
5 |
class SemanticSearcher:
|
6 |
-
def __init__(self, df_counsel_chat_topic):
|
7 |
self.df_counsel_chat_topic = df_counsel_chat_topic
|
|
|
8 |
self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
9 |
self.question_embeddings = self.embedder.encode(
|
10 |
self.df_counsel_chat_topic["questionCombined"].tolist(),
|
11 |
show_progress_bar=True,
|
12 |
convert_to_tensor=True,
|
13 |
)
|
|
|
14 |
def retrieve_relevant_qna(self, question: str, question_context: str = None):
|
15 |
if question_context is None:
|
16 |
question_context = ""
|
|
|
3 |
|
4 |
|
5 |
class SemanticSearcher:
|
6 |
+
def __init__(self, df_counsel_chat_topic, df_counsel_chat):
|
7 |
self.df_counsel_chat_topic = df_counsel_chat_topic
|
8 |
+
self.df_counsel_chat = df_counsel_chat
|
9 |
self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
10 |
self.question_embeddings = self.embedder.encode(
|
11 |
self.df_counsel_chat_topic["questionCombined"].tolist(),
|
12 |
show_progress_bar=True,
|
13 |
convert_to_tensor=True,
|
14 |
)
|
15 |
+
|
16 |
def retrieve_relevant_qna(self, question: str, question_context: str = None):
|
17 |
if question_context is None:
|
18 |
question_context = ""
|
src/upvote_predictor.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import numpy as np
|
2 |
-
import pandas as pd
|
3 |
import torch
|
4 |
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
|
5 |
from transformers import BertTokenizer
|
@@ -61,9 +60,6 @@ class UpvotePredictor:
|
|
61 |
predictions.extend(list(pred_flat))
|
62 |
|
63 |
if predictions[0] == 0:
|
64 |
-
return
|
65 |
-
"Not credible suggestion",
|
66 |
-
pd.DataFrame(),
|
67 |
-
)
|
68 |
else:
|
69 |
-
return
|
|
|
1 |
import numpy as np
|
|
|
2 |
import torch
|
3 |
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
|
4 |
from transformers import BertTokenizer
|
|
|
60 |
predictions.extend(list(pred_flat))
|
61 |
|
62 |
if predictions[0] == 0:
|
63 |
+
return "Not credible suggestion"
|
|
|
|
|
|
|
64 |
else:
|
65 |
+
return "Credible suggestion"
|