brestok commited on
Commit
3d560e1
·
verified ·
1 Parent(s): 498d1ce

Update project/bot/openai_backend.py

Browse files
Files changed (1) hide show
  1. project/bot/openai_backend.py +5 -4
project/bot/openai_backend.py CHANGED
@@ -77,13 +77,13 @@ class SearchBot:
77
 
78
  async def analyze_full_response(self) -> str:
79
  assistant_message = self.chat_history.pop()['content']
80
- nlp = pipeline("ner", model=settings.NLP_MODEL, tokenizer=settings.NLP_TOKENIZER, grouped_entities=True)
81
  ner_result = nlp(assistant_message)
82
  analyzed_assistant_message = assistant_message
83
  for entity in ner_result:
84
  if entity['entity_group'] in ("LOC", "ORG", "MISC") and entity['word'] != "Javea":
85
  enriched_information = await self.enrich_information_from_google(entity['word'])
86
- analyzed_assistant_message = analyzed_assistant_message.replace(entity['word'], enriched_information)
87
  return "ENRICHED:" + analyzed_assistant_message
88
 
89
  async def _convert_to_embeddings(self, text_list):
@@ -96,7 +96,7 @@ class SearchBot:
96
 
97
  @staticmethod
98
  async def _get_context_data(user_query: list[float]) -> list[dict]:
99
- radius = 4
100
  _, distances, indices = settings.FAISS_INDEX.range_search(user_query, radius)
101
  indices_distances_df = pd.DataFrame({'index': indices, 'distance': distances})
102
  filtered_data_df = settings.products_dataset.iloc[indices].copy()
@@ -110,7 +110,8 @@ class SearchBot:
110
  async def create_context_str(context: List[Dict]) -> str:
111
  context_str = ''
112
  for i, chunk in enumerate(context):
113
- context_str += f'{i + 1}) {chunk["chunks"]}'
 
114
  return context_str
115
 
116
  async def _rag(self, context: List[Dict], query: str, session: AsyncSession, country: str):
 
77
 
78
  async def analyze_full_response(self) -> str:
79
  assistant_message = self.chat_history.pop()['content']
80
+ nlp = pipeline("ner", model=settings.NLP_MODEL, tokenizer=settings.NLP_TOKENIZER, aggregation_strategy="simple")
81
  ner_result = nlp(assistant_message)
82
  analyzed_assistant_message = assistant_message
83
  for entity in ner_result:
84
  if entity['entity_group'] in ("LOC", "ORG", "MISC") and entity['word'] != "Javea":
85
  enriched_information = await self.enrich_information_from_google(entity['word'])
86
+ analyzed_assistant_message = analyzed_assistant_message.replace(entity['word'], enriched_information, 1)
87
  return "ENRICHED:" + analyzed_assistant_message
88
 
89
  async def _convert_to_embeddings(self, text_list):
 
96
 
97
  @staticmethod
98
  async def _get_context_data(user_query: list[float]) -> list[dict]:
99
+ radius = 5
100
  _, distances, indices = settings.FAISS_INDEX.range_search(user_query, radius)
101
  indices_distances_df = pd.DataFrame({'index': indices, 'distance': distances})
102
  filtered_data_df = settings.products_dataset.iloc[indices].copy()
 
110
  async def create_context_str(context: List[Dict]) -> str:
111
  context_str = ''
112
  for i, chunk in enumerate(context):
113
+ if "Comments:" in chunk['chunks']:
114
+ context_str += f'{i + 1}) {chunk["chunks"]}'
115
  return context_str
116
 
117
  async def _rag(self, context: List[Dict], query: str, session: AsyncSession, country: str):