import asyncio
import logging
from typing import List, Dict

import faiss
import numpy as np
import pandas as pd
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.websockets import WebSocket

from project.bot.models import MessagePair
from project.config import settings


class SearchBot:
    """Chat bot that answers queries with FAISS-retrieved context and OpenAI.

    The conversation is kept in ``chat_history`` as a list of
    ``{"role": ..., "content": ...}`` dicts and is sent to the chat completion
    endpoint on every turn.
    """

    # Third-party types in method signatures are written as string annotations
    # so they are not evaluated when the class body runs.

    # Class-level fallback only: ``__init__`` always binds a per-instance list.
    chat_history = []

    # ``_summarize_user_intent`` reads the last ``unknown_counter * 2`` history
    # entries. NOTE(review): nothing in this class assigned the attribute, so
    # that method raised AttributeError unless a caller set it first. With the
    # default of 0 the slice covers the whole history -- confirm the intent.
    unknown_counter = 0

    _log = logging.getLogger(__name__)

    def __init__(self, memory=None):
        """Create a bot, optionally resuming from an existing history.

        Args:
            memory: List of message dicts to use as the chat history. The list
                is stored by reference (not copied), so messages appended by
                the bot are visible to the caller. ``None`` starts empty.
        """
        self.chat_history = [] if memory is None else memory

    async def _summarize_user_intent(self, user_query: str) -> str:
        """Ask the model to summarise the user's intent.

        A single system message is built from ``settings.SUMMARIZE_PROMPT``,
        the user-authored entries found in the most recent part of the
        history, and ``user_query``.

        Returns:
            The content of the first completion choice.
        """
        # ``-0`` slices from the start, so a counter of 0 selects everything.
        recent = self.chat_history[-self.unknown_counter * 2:]
        chat_history_str = ''.join(
            f"{message['role']}: {message['content']}\n"
            for message in recent
            if message['role'] == 'user'
        )
        messages = [
            {
                'role': 'system',
                'content': f"{settings.SUMMARIZE_PROMPT}\n"
                           f"Chat history: ```{chat_history_str}```\n"
                           f"User query: ```{user_query}```",
            },
        ]
        response = await settings.OPENAI_CLIENT.chat.completions.create(
            messages=messages,
            temperature=0.1,
            n=1,
            model="gpt-3.5-turbo-0125",
        )
        return response.choices[0].message.content

    @staticmethod
    def _cls_pooling(model_output):
        """Return position 0 along the second axis of ``last_hidden_state``."""
        return model_output.last_hidden_state[:, 0]

    async def _convert_to_embeddings(self, text_list):
        """Embed text with ``settings.INFO_TOKENIZER`` / ``settings.INFO_MODEL``.

        Args:
            text_list: Text passed straight to the tokenizer.

        Returns:
            A float32 numpy array holding the pooled model output.
        """
        # NOTE: tokenizer and model are called synchronously, so this coroutine
        # holds the event loop for the duration of the forward pass.
        encoded_input = settings.INFO_TOKENIZER(
            text_list, padding=True, truncation=True, return_tensors="pt"
        )
        encoded_input = {
            name: tensor.to(settings.device)
            for name, tensor in encoded_input.items()
        }
        model_output = settings.INFO_MODEL(**encoded_input)
        pooled = self._cls_pooling(model_output)
        return pooled.cpu().detach().numpy().astype('float32')

    @staticmethod
    async def _get_context_data(
        user_query: "np.ndarray", *, radius: float = 30, top_k: int = 3
    ) -> list[dict]:
        """Look up the dataset rows closest to the query embedding.

        Args:
            user_query: Query embedding as produced by
                ``_convert_to_embeddings``.
            radius: Threshold handed to ``FAISS_INDEX.range_search``.
            top_k: Maximum number of rows returned.

        Returns:
            Up to ``top_k`` rows of ``settings.products_dataset`` as dicts,
            ordered by ascending distance. Empty when nothing is in range.
        """
        # range_search returns (lims, distances, indices). ``lims`` is ignored,
        # which assumes a single query row -- TODO confirm against callers.
        _, distances, indices = settings.FAISS_INDEX.range_search(user_query, radius)
        matches = settings.products_dataset.iloc[indices].copy()
        # Temporary column, used only for ordering and dropped before returning.
        matches['distance'] = distances
        nearest = matches.sort_values(by='distance').head(top_k)
        return nearest.drop(columns='distance').to_dict(orient='records')

    @staticmethod
    async def create_context_str(context: list[dict]) -> str:
        """Render context rows as a numbered list in one string.

        Each row contributes ``"<n>) <row['chunks']>"`` with ``n`` starting at
        1. No separator is inserted between entries.
        """
        return ''.join(
            f'{position}) {chunk["chunks"]}'
            for position, chunk in enumerate(context, start=1)
        )

    async def _rag(self, context: list[dict], query: str, session: "AsyncSession", country: str):
        """Stream an answer to ``query``, yielding the text accumulated so far.

        When ``context`` is non-empty it is rendered and appended to the
        history as an assistant message, and ``settings.PROMPT`` is used as the
        system prompt; otherwise ``settings.EMPTY_PROMPT`` is used. After the
        stream ends the full answer is appended to the history and a
        ``MessagePair`` is added to ``session`` (this method does not commit).

        Yields:
            The cumulative response text after every content-bearing chunk.
        """
        if context:
            context_str = await self.create_context_str(context)
            self.chat_history.append({"role": 'assistant', "content": context_str})
            system_prompt = settings.PROMPT
        else:
            system_prompt = settings.EMPTY_PROMPT

        self.chat_history.append({"role": 'user', "content": query})
        messages = [{'role': 'system', 'content': system_prompt}, *self.chat_history]

        stream = await settings.OPENAI_CLIENT.chat.completions.create(
            messages=messages,
            temperature=0.1,
            n=1,
            model="gpt-3.5-turbo",
            stream=True,
        )
        response = ''
        async for chunk in stream:
            # A chunk without choices carries no text; indexing it would raise
            # IndexError and abort the whole stream.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece is not None:
                response += piece
                yield response
                # Short pause between sends, kept from the original pacing.
                await asyncio.sleep(0.02)

        self.chat_history.append({"role": 'assistant', "content": response})
        try:
            session.add(MessagePair(user_message=query, bot_response=response, country=country))
        except Exception:
            # Best effort: failing to record the pair must not break the chat.
            self._log.exception("Could not add the message pair to the session")

    async def ask_and_send(self, data: dict, websocket: "WebSocket", session: "AsyncSession"):
        """Answer ``data['query']`` and send each streamed update over the websocket.

        Args:
            data: Mapping with ``'query'`` and ``'country'`` keys.
            websocket: Connection that receives the cumulative response text.
            session: Database session the exchange is added to.

        Any exception raised while streaming is logged and followed by
        ``emergency_db_saving``; it is not re-raised.
        """
        query = data['query']
        country = data['country']
        transformed_query = await self._convert_to_embeddings(query)
        context = await self._get_context_data(transformed_query)
        try:
            async for chunk in self._rag(context, query, session, country):
                await websocket.send_text(chunk)
        except Exception:
            # Record why streaming stopped instead of discarding the error.
            self._log.exception("Streaming the answer failed; committing pending data")
            await self.emergency_db_saving(session)

    @staticmethod
    async def emergency_db_saving(session: "AsyncSession"):
        """Commit whatever is pending on ``session`` and then close it."""
        await session.commit()
        await session.close()