Update project/bot/openai_backend.py
Browse files
project/bot/openai_backend.py
CHANGED
@@ -77,13 +77,13 @@ class SearchBot:
|
|
77 |
|
78 |
async def analyze_full_response(self) -> str:
|
79 |
assistant_message = self.chat_history.pop()['content']
|
80 |
-
nlp = pipeline("ner", model=settings.NLP_MODEL, tokenizer=settings.NLP_TOKENIZER,
|
81 |
ner_result = nlp(assistant_message)
|
82 |
analyzed_assistant_message = assistant_message
|
83 |
for entity in ner_result:
|
84 |
if entity['entity_group'] in ("LOC", "ORG", "MISC") and entity['word'] != "Javea":
|
85 |
enriched_information = await self.enrich_information_from_google(entity['word'])
|
86 |
-
analyzed_assistant_message = analyzed_assistant_message.replace(entity['word'], enriched_information)
|
87 |
return "ENRICHED:" + analyzed_assistant_message
|
88 |
|
89 |
async def _convert_to_embeddings(self, text_list):
|
@@ -96,7 +96,7 @@ class SearchBot:
|
|
96 |
|
97 |
@staticmethod
|
98 |
async def _get_context_data(user_query: list[float]) -> list[dict]:
|
99 |
-
radius =
|
100 |
_, distances, indices = settings.FAISS_INDEX.range_search(user_query, radius)
|
101 |
indices_distances_df = pd.DataFrame({'index': indices, 'distance': distances})
|
102 |
filtered_data_df = settings.products_dataset.iloc[indices].copy()
|
@@ -110,7 +110,8 @@ class SearchBot:
|
|
110 |
async def create_context_str(context: List[Dict]) -> str:
|
111 |
context_str = ''
|
112 |
for i, chunk in enumerate(context):
|
113 |
-
|
|
|
114 |
return context_str
|
115 |
|
116 |
async def _rag(self, context: List[Dict], query: str, session: AsyncSession, country: str):
|
|
|
77 |
|
78 |
async def analyze_full_response(self) -> str:
    """Enrich named entities in the latest chat message and return the result.

    Pops the last entry from ``self.chat_history``, runs the "ner" pipeline
    over its ``'content'``, and replaces each entity whose group is LOC, ORG
    or MISC (except the word "Javea") with the text returned by
    ``self.enrich_information_from_google``.

    Replacements are spliced into the *original* message by character
    position rather than applied one after another with ``str.replace``, so
    text inserted for one entity is never matched again when a later entity
    is processed (e.g. the same place name detected twice).

    Returns:
        The enriched message prefixed with ``"ENRICHED:"``.
    """
    assistant_message = self.chat_history.pop()['content']
    # NOTE(review): the pipeline is rebuilt on every call; if construction is
    # expensive, consider creating it once and reusing it.
    nlp = pipeline(
        "ner",
        model=settings.NLP_MODEL,
        tokenizer=settings.NLP_TOKENIZER,
        aggregation_strategy="simple",
    )
    ner_result = nlp(assistant_message)

    pieces = []
    cursor = 0  # end of the last span of assistant_message already consumed
    for entity in ner_result:
        word = entity['word']
        if entity['entity_group'] not in ("LOC", "ORG", "MISC") or word == "Javea":
            continue
        enriched_information = await self.enrich_information_from_google(word)

        start, end = entity.get('start'), entity.get('end')
        if start is None or end is None or start < cursor:
            # Offsets may be absent (they depend on the tokenizer) or may
            # overlap an earlier replacement; fall back to locating the word
            # in the not-yet-consumed part of the original message.
            start = assistant_message.find(word, cursor)
            if start == -1:
                # Nothing to replace; leave the message text untouched.
                continue
            end = start + len(word)

        pieces.append(assistant_message[cursor:start])
        pieces.append(enriched_information)
        cursor = end

    pieces.append(assistant_message[cursor:])
    return "ENRICHED:" + "".join(pieces)
|
88 |
|
89 |
async def _convert_to_embeddings(self, text_list):
|
|
|
96 |
|
97 |
@staticmethod
|
98 |
async def _get_context_data(user_query: list[float]) -> list[dict]:
|
99 |
+
radius = 5
|
100 |
_, distances, indices = settings.FAISS_INDEX.range_search(user_query, radius)
|
101 |
indices_distances_df = pd.DataFrame({'index': indices, 'distance': distances})
|
102 |
filtered_data_df = settings.products_dataset.iloc[indices].copy()
|
|
|
110 |
async def create_context_str(context: List[Dict]) -> str:
    """Build one numbered string from the chunks that contain "Comments:".

    Args:
        context: Items that are each indexed by the ``'chunks'`` key.

    Returns:
        The kept chunks concatenated with no separator, each prefixed by
        ``"<position>) "``, where position is the item's 1-based index in
        ``context`` (items that are filtered out still consume a number).
        An empty string when nothing qualifies.
    """
    numbered_parts = []
    for position, item in enumerate(context, start=1):
        text = item['chunks']
        if "Comments:" not in text:
            continue
        numbered_parts.append(f'{position}) {text}')
    return ''.join(numbered_parts)
|
116 |
|
117 |
async def _rag(self, context: List[Dict], query: str, session: AsyncSession, country: str):
|