|
import asyncio
import html
import json
import logging
import re
from typing import List, Dict

import faiss
import httpx
import numpy as np
import pandas as pd
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.websockets import WebSocket
from transformers import pipeline

from project.bot.models import MessagePair
from project.config import settings
|
|
|
|
|
class SearchBot:
    """Retrieval-augmented chat bot that streams answers over a websocket.

    For each query the bot embeds the text, looks up nearby chunks in the
    FAISS index from ``settings``, streams an OpenAI chat completion, and then
    sends a second, "enriched" version of the answer in which named entities
    are wrapped in Google Places links/tooltips.
    """

    # Class-level default kept for backward compatibility; every instance
    # rebinds ``chat_history`` to its own list in ``__init__``.
    chat_history = []

    _log = logging.getLogger(__name__)

    # Lazily created NER pipeline, shared so the model is loaded only once
    # instead of on every analysed response.
    _ner_pipeline = None

    def __init__(self, memory=None):
        """Create a bot, optionally seeded with an existing message list.

        Args:
            memory: List of ``{"role": ..., "content": ...}`` dicts used as the
                chat history. The list is stored by reference, not copied.
                Defaults to a fresh empty list.
        """
        self.chat_history = [] if memory is None else memory

    @staticmethod
    def _cls_pooling(model_output):
        """Return the first-token slice ``[:, 0]`` of ``last_hidden_state``."""
        return model_output.last_hidden_state[:, 0]

    @classmethod
    def _get_ner_pipeline(cls):
        """Return the cached NER pipeline, building it on first use."""
        if cls._ner_pipeline is None:
            cls._ner_pipeline = pipeline(
                "ner",
                model=settings.NLP_MODEL,
                tokenizer=settings.NLP_TOKENIZER,
                aggregation_strategy="simple",
            )
        return cls._ner_pipeline

    @staticmethod
    async def _fetch_place(client, search_word: str):
        """Run a Places text search and return the first place dict, or None."""
        url = "https://places.googleapis.com/v1/places:searchText"
        headers = {
            "Content-Type": "application/json",
            "X-Goog-Api-Key": settings.GOOGLE_PLACES_API_KEY,
            "X-Goog-FieldMask": "places.shortFormattedAddress,places.websiteUri,places.internationalPhoneNumber,"
                                "places.googleMapsUri,places.photos",
        }
        payload = {
            "textQuery": f"{search_word} in Javea",
            "languageCode": "nl",
            "maxResultCount": 1,
        }
        response = await client.post(url, headers=headers, content=json.dumps(payload))
        response.raise_for_status()
        # A search without matches has no usable "places" entry.
        places = response.json().get("places") or []
        return places[0] if places else None

    @staticmethod
    async def _fetch_photo_uri(client, photos):
        """Resolve the first photo resource to an image URL, or None on failure."""
        if not photos:
            return None
        photo_name = photos[0].get("name")
        if not photo_name:
            return None
        try:
            response = await client.get(
                f"https://places.googleapis.com/v1/{photo_name}/media",
                params={
                    "maxWidthPx": 350,
                    # Without this flag the endpoint redirects to the image
                    # bytes instead of returning JSON containing "photoUri".
                    "skipHttpRedirect": "true",
                    "key": settings.GOOGLE_PLACES_API_KEY,
                },
            )
            response.raise_for_status()
            return response.json().get("photoUri")
        except (httpx.HTTPError, ValueError):
            # The photo is optional: keep the rest of the tooltip.
            SearchBot._log.warning("Could not resolve place photo %r", photo_name, exc_info=True)
            return None

    @staticmethod
    def _render_place_tooltip(search_word: str, place: dict, photo_uri) -> str:
        """Build the link + tooltip HTML for a place; all dynamic values are escaped."""
        maps_uri = html.escape(place["googleMapsUri"], quote=True)
        parts = [
            f'<a class="extraDataLink" href="{maps_uri}" target="_blank">{html.escape(search_word)}</a>'
            '<div class="tooltip-elem">'
        ]
        if photo_uri:
            parts.append(f'<img src="{html.escape(photo_uri, quote=True)}" alt="Image" class="tooltip-img">')
        address = place.get("shortFormattedAddress")
        if address:
            parts.append(f'<p><a href="{maps_uri}" target="_blank">{html.escape(address)}</a></p>')
        website_uri = place.get("websiteUri")
        if website_uri:
            # Label fixed: this link targets the place's own website, not Google Maps.
            parts.append(f'<p><a href="{html.escape(website_uri, quote=True)}">Website</a></p>')
        phone_number = place.get("internationalPhoneNumber")
        if phone_number:
            phone_str = re.sub(r"\s+", "", phone_number)
            parts.append(f'<p><a href="tel:{html.escape(phone_str, quote=True)}">Phone number</a></p>')
        parts.append("</div>")
        return "".join(parts)

    @staticmethod
    async def enrich_information_from_google(search_word: str) -> str:
        """Wrap ``search_word`` in a Google Places link with a details tooltip.

        Args:
            search_word: Place or organisation name to look up (the query sent
                to Google is ``"<search_word> in Javea"``).

        Returns:
            An HTML snippet (anchor plus tooltip ``div``) when a place with a
            Google Maps URI is found. Otherwise, including when the request
            fails or returns an unreadable body, ``search_word`` unchanged, so
            enrichment stays best-effort.
        """
        try:
            # One client for both requests so the connection is reused.
            async with httpx.AsyncClient() as client:
                place = await SearchBot._fetch_place(client, search_word)
                if not place or not place.get("googleMapsUri"):
                    return search_word
                photo_uri = await SearchBot._fetch_photo_uri(client, place.get("photos"))
        except (httpx.HTTPError, ValueError):
            SearchBot._log.warning("Google Places lookup failed for %r", search_word, exc_info=True)
            return search_word
        return SearchBot._render_place_tooltip(search_word, place, photo_uri)

    async def analyze_full_response(self) -> str:
        """Enrich the latest history message with Google Places tooltips.

        Runs NER over the last message in ``chat_history`` and replaces each
        ``LOC``/``ORG``/``MISC`` entity (except "Javea") with the HTML from
        ``enrich_information_from_google``.

        Returns:
            The enriched message prefixed with ``"ENRICHED:"``.

        Raises:
            IndexError: If ``chat_history`` is empty.
        """
        # NOTE(review): pop() removes the message from the history, so later
        # turns no longer contain this reply. Confirm that this is intended.
        assistant_message = self.chat_history.pop()['content']
        ner_result = self._get_ner_pipeline()(assistant_message)

        # Assemble the output from slices of the ORIGINAL text, moving a cursor
        # forward, so a later entity can never match inside HTML inserted for
        # an earlier one.
        segments = []
        cursor = 0
        for entity in ner_result:
            word = entity['word']
            if entity['entity_group'] not in ("LOC", "ORG", "MISC") or word == "Javea":
                continue
            if not word:
                continue
            position = assistant_message.find(word, cursor)
            if position == -1:
                # The word is not literally present (or was already consumed);
                # skip the lookup, since there is nothing to replace.
                continue
            enriched_information = await self.enrich_information_from_google(word)
            segments.append(assistant_message[cursor:position])
            segments.append(enriched_information)
            cursor = position + len(word)
        segments.append(assistant_message[cursor:])
        return "ENRICHED:" + "".join(segments)

    async def _convert_to_embeddings(self, text_list):
        """Embed text with ``settings.INFO_MODEL`` and return a float32 NumPy array.

        Args:
            text_list: Text passed straight to ``settings.INFO_TOKENIZER``
                (``ask_and_send`` passes a single query string).
        """
        encoded_input = settings.INFO_TOKENIZER(
            text_list, padding=True, truncation=True, return_tensors="pt"
        )
        encoded_input = {k: v.to(settings.device) for k, v in encoded_input.items()}
        model_output = settings.INFO_MODEL(**encoded_input)
        return self._cls_pooling(model_output).cpu().detach().numpy().astype('float32')

    @staticmethod
    async def _get_context_data(user_query: np.ndarray) -> list[dict]:
        """Return up to three nearest dataset rows that contain comments.

        Args:
            user_query: Query embedding as produced by
                ``_convert_to_embeddings``.

        Returns:
            Records (dicts) of the three closest rows of
            ``settings.products_dataset`` within the search radius, keeping
            only those whose ``"chunks"`` text contains ``"Comments:"``.
        """
        radius = 5
        _, distances, indices = settings.FAISS_INDEX.range_search(user_query, radius)
        matches: pd.DataFrame = settings.products_dataset.iloc[indices].copy()
        # ``distances`` is aligned positionally with ``indices``.
        matches['distance'] = distances
        nearest = matches.sort_values(by='distance').head(3).drop('distance', axis=1)
        return [
            chunk
            for chunk in nearest.to_dict(orient='records')
            if "Comments:" in chunk['chunks']
        ]

    @staticmethod
    async def create_context_str(context: List[Dict]) -> str:
        """Concatenate context chunks as ``"1) ...2) ..."`` (no separator added)."""
        return ''.join(f'{i + 1}) {chunk["chunks"]}' for i, chunk in enumerate(context))

    async def _rag(self, context: List[Dict], query: str, session: "AsyncSession", country: str):
        """Stream a chat completion for ``query``, yielding the text so far.

        Async generator. When ``context`` is non-empty it is appended to the
        history as an assistant message and ``settings.PROMPT`` is used as the
        system prompt; otherwise ``settings.EMPTY_PROMPT`` is used. After the
        stream ends the full reply is appended to ``chat_history`` and a
        ``MessagePair`` is added to ``session`` (not committed here).

        Yields:
            The accumulated response text after each streamed delta.
        """
        if context:
            context_str = await self.create_context_str(context)
            self.chat_history.append({"role": 'assistant', "content": context_str})
            system_prompt = settings.PROMPT
        else:
            system_prompt = settings.EMPTY_PROMPT
        self.chat_history.append({"role": 'user', "content": query})
        messages = [{'role': 'system', 'content': system_prompt}, *self.chat_history]

        stream = await settings.OPENAI_CLIENT.chat.completions.create(
            messages=messages,
            temperature=0.1,
            n=1,
            model="gpt-3.5-turbo",
            stream=True
        )
        response = ''
        async for chunk in stream:
            # Skip chunks that carry no choice rather than failing on [0].
            if not chunk.choices:
                continue
            chunk_content = chunk.choices[0].delta.content
            if chunk_content is not None:
                response += chunk_content
                yield response
                await asyncio.sleep(0.02)
        self.chat_history.append({"role": 'assistant', "content": response})
        try:
            session.add(MessagePair(user_message=query, bot_response=response, country=country))
        except Exception:
            # Persisting the exchange is best-effort; never break the chat.
            self._log.exception("Could not queue message pair for saving")

    async def ask_and_send(self, data: Dict, websocket: "WebSocket", session: "AsyncSession"):
        """Answer ``data['query']`` over ``websocket``.

        Streams the growing answer, then sends the ``"ENRICHED:"`` version.
        If streaming or enrichment fails, the error is logged and the session
        is committed and closed via ``emergency_db_saving``.

        Args:
            data: Mapping with ``'query'`` and ``'country'`` keys.
            websocket: Connection the text messages are sent on.
            session: Database session used to store the exchange.

        Raises:
            KeyError: If ``data`` lacks ``'query'`` or ``'country'``.
        """
        query = data['query']
        country = data['country']
        transformed_query = await self._convert_to_embeddings(query)
        context = await self._get_context_data(transformed_query)
        try:
            async for chunk in self._rag(context, query, session, country):
                await websocket.send_text(chunk)
            analyzing = await self.analyze_full_response()
            await websocket.send_text(analyzing)
        except Exception:
            # Deliberate catch-all (e.g. client disconnect): record the cause,
            # then persist whatever was already added to the session.
            self._log.exception("Failed while answering query; saving session state")
            await self.emergency_db_saving(session)

    @staticmethod
    async def emergency_db_saving(session: "AsyncSession"):
        """Commit pending changes on ``session`` and close it."""
        await session.commit()
        await session.close()
|
|