import asyncio
import html
import json
import logging
import re
from typing import Dict, List

import faiss
import httpx
import numpy as np
import pandas as pd
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.websockets import WebSocket
from transformers import pipeline

from project.bot.models import MessagePair
from project.config import settings
class SearchBot:
    """Retrieval-augmented chat bot that streams answers over a websocket.

    Each turn embeds the user's query, pulls matching rows from a FAISS index
    over ``settings.products_dataset``, streams an OpenAI chat completion back
    to the client and finally sends an "ENRICHED:" copy of the answer in which
    recognised locations/organisations are turned into HTML links with
    Google Places tooltips.
    """

    # Class-level default kept for backward compatibility; every instance
    # shadows it with its own list in __init__, so it is never shared.
    chat_history: List[Dict] = []

    # Shared NER pipeline, built lazily by _get_ner_pipeline().
    _ner_pipeline = None

    # NER entity groups that are looked up on Google Places.
    _ENRICHABLE_GROUPS = ("LOC", "ORG", "MISC")
    # Place every Google query is scoped to; the town itself is never enriched.
    _HOME_TOWN = "Javea"

    _PLACES_SEARCH_URL = "https://places.googleapis.com/v1/places:searchText"

    _logger = logging.getLogger(__name__)

    def __init__(self, memory=None):
        """Create a bot whose conversation memory is *memory* (a new list if None).

        The list is used and mutated in place, so a caller-supplied list keeps
        receiving this bot's messages.
        """
        self.chat_history = [] if memory is None else memory

    @staticmethod
    def _cls_pooling(model_output):
        """Return the first-token (CLS) hidden state of every sequence in the batch."""
        return model_output.last_hidden_state[:, 0]

    @staticmethod
    async def _fetch_photo_uri(client: httpx.AsyncClient, place: Dict):
        """Return a displayable URL for the first photo of *place*, or None.

        Best-effort: any HTTP/transport error or malformed response yields None
        so that a missing photo never prevents the rest of the enrichment.
        """
        photos = place.get('photos')
        if not photos:
            return None
        try:
            response = await client.get(
                f"https://places.googleapis.com/v1/{photos[0]['name']}/media",
                params={
                    "maxWidthPx": 350,
                    # Without this the endpoint redirects to the image bytes
                    # instead of returning JSON that contains ``photoUri``.
                    "skipHttpRedirect": "true",
                    "key": settings.GOOGLE_PLACES_API_KEY,
                },
            )
            response.raise_for_status()
            return response.json().get('photoUri')
        except (httpx.HTTPError, ValueError, KeyError):
            SearchBot._logger.warning("Google Places photo lookup failed", exc_info=True)
            return None

    @staticmethod
    async def enrich_information_from_google(search_word: str) -> str:
        """Return *search_word* wrapped in an HTML link plus tooltip from Google Places.

        Queries the Places Text Search endpoint for "<search_word> in Javea"
        (one result, Dutch). When the match has a Google Maps URI, returns an
        ``extraDataLink`` anchor to it followed by a ``tooltip-elem`` div with
        whichever of photo, address, website and phone number are available.
        All interpolated values are HTML-escaped.

        Enrichment is best-effort: on a request failure, invalid JSON, no match
        or a match without a Maps URI, *search_word* is returned unchanged.
        """
        headers = {
            "Content-Type": "application/json",
            "X-Goog-Api-Key": settings.GOOGLE_PLACES_API_KEY,
            "X-Goog-FieldMask": "places.shortFormattedAddress,places.websiteUri,places.internationalPhoneNumber,"
                                "places.googleMapsUri,places.photos",
        }
        payload = {
            "textQuery": f"{search_word} in {SearchBot._HOME_TOWN}",
            "languageCode": "nl",
            "maxResultCount": 1,
        }
        try:
            # One client for both the search and the photo request.
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    SearchBot._PLACES_SEARCH_URL, headers=headers, content=json.dumps(payload)
                )
                response.raise_for_status()
                # No "places" key means there is nothing to enrich with.
                places = response.json().get('places') or []
                place = places[0] if places else {}
                google_maps_uri = place.get('googleMapsUri')
                if not google_maps_uri:
                    return search_word
                photo_uri = await SearchBot._fetch_photo_uri(client, place)
        except (httpx.HTTPError, ValueError):
            SearchBot._logger.warning("Google Places lookup failed for %r", search_word, exc_info=True)
            return search_word

        maps_href = html.escape(google_maps_uri)
        parts = [
            f'<a class="extraDataLink" href="{maps_href}" target="_blank">{html.escape(search_word)}</a>'
            '<div class="tooltip-elem">'
        ]
        if photo_uri:
            parts.append(f'<img src="{html.escape(photo_uri)}" alt="Image" class="tooltip-img">')
        formatted_address = place.get('shortFormattedAddress')
        if formatted_address:
            parts.append(f'<p><a href="{maps_href}" target="_blank">{html.escape(formatted_address)}</a></p>')
        website_uri = place.get('websiteUri')
        if website_uri:
            parts.append(f'<p><a href="{html.escape(website_uri)}" target="_blank">Website</a></p>')
        phone_number = place.get('internationalPhoneNumber')
        if phone_number:
            # Spaces are stripped for the tel: URI.
            tel = html.escape(phone_number.replace(' ', ''))
            parts.append(f'<p><a href="tel:{tel}">Phone number</a></p>')
        parts.append("</div>")
        return "".join(parts)

    @classmethod
    def _get_ner_pipeline(cls):
        """Return the shared NER pipeline, building it on first use.

        Building a transformers pipeline is expensive, so it is done once per
        process rather than once per answer.
        """
        if cls._ner_pipeline is None:
            cls._ner_pipeline = pipeline(
                "ner",
                model=settings.NLP_MODEL,
                tokenizer=settings.NLP_TOKENIZER,
                aggregation_strategy="simple",
            )
        return cls._ner_pipeline

    @classmethod
    def _entity_spans(cls, text: str, entities: List[Dict]) -> List[tuple]:
        """Return non-overlapping ``(start, end, surface)`` spans of enrichable entities.

        Uses the pipeline's character offsets when present; otherwise falls back
        to locating ``entity['word']`` in *text*. Entities are assumed to arrive
        in text order (as the aggregated NER pipeline emits them); spans that
        would overlap an earlier one are skipped.
        """
        spans = []
        cursor = 0
        for entity in entities:
            if entity['entity_group'] not in cls._ENRICHABLE_GROUPS:
                continue
            start, end = entity.get('start'), entity.get('end')
            if start is None or end is None:
                start = text.find(entity['word'], cursor)
                if start == -1:
                    continue
                end = start + len(entity['word'])
            raw = text[start:end]
            surface = raw.strip()
            if not surface:
                continue
            # Trim surrounding whitespace so it is not swallowed by the link.
            start += len(raw) - len(raw.lstrip())
            end = start + len(surface)
            if start < cursor or cls._HOME_TOWN in (entity['word'], surface):
                continue
            spans.append((start, end, surface))
            cursor = end
        return spans

    async def analyze_full_response(self) -> str:
        """Return the latest history message with its entities turned into HTML links.

        Runs NER over the last message in ``chat_history`` (the assistant's
        answer after ``_rag``), looks up every distinct LOC/ORG/MISC entity other
        than the home town on Google Places concurrently, and splices the
        enriched HTML in at each entity's position. The result is prefixed with
        "ENRICHED:". The history itself is left untouched.
        """
        # Peek rather than pop so the answer stays in the conversation memory.
        message = self.chat_history[-1]['content']
        entities = self._get_ner_pipeline()(message)
        spans = self._entity_spans(message, entities)

        unique_words = list(dict.fromkeys(surface for _, _, surface in spans))
        enriched = await asyncio.gather(
            *(self.enrich_information_from_google(word) for word in unique_words)
        )
        replacement = dict(zip(unique_words, enriched))

        # Rebuild from offsets so inserted HTML is never searched again.
        pieces = []
        cursor = 0
        for start, end, surface in spans:
            pieces.append(message[cursor:start])
            pieces.append(replacement[surface])
            cursor = end
        pieces.append(message[cursor:])
        return "ENRICHED:" + "".join(pieces)

    async def _convert_to_embeddings(self, text_list):
        """Embed *text_list* (a string or list of strings) with the info model.

        Returns a float32 NumPy array with the CLS embedding of each input,
        one row per text.
        """
        import torch  # the tokenizer output is PyTorch tensors (return_tensors="pt")

        encoded_input = settings.INFO_TOKENIZER(
            text_list, padding=True, truncation=True, return_tensors="pt"
        )
        encoded_input = {k: v.to(settings.device) for k, v in encoded_input.items()}
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            model_output = settings.INFO_MODEL(**encoded_input)
        return self._cls_pooling(model_output).detach().cpu().numpy().astype('float32')

    @staticmethod
    async def _get_context_data(user_query: np.ndarray, radius: float = 5, top_k: int = 3) -> List[Dict]:
        """Return the nearest dataset rows to *user_query* that contain comments.

        Runs a FAISS range search with *radius*, orders the hits by ascending
        distance, keeps the *top_k* nearest rows as dicts, then drops any whose
        ``chunks`` text lacks "Comments:" (so fewer than *top_k* may be returned).
        """
        _, distances, indices = settings.FAISS_INDEX.range_search(user_query, radius)
        hits = settings.products_dataset.iloc[indices].copy()
        # Positional assignment: the array is aligned with ``indices``.
        hits['distance'] = distances
        nearest = (
            hits.sort_values(by='distance')
            .drop(columns='distance')
            .head(top_k)
            .to_dict(orient='records')
        )
        return [row for row in nearest if "Comments:" in row['chunks']]

    @staticmethod
    async def create_context_str(context: List[Dict]) -> str:
        """Render retrieved rows as numbered lines: "1) <chunks>", "2) <chunks>", ..."""
        return "\n".join(f'{i}) {chunk["chunks"]}' for i, chunk in enumerate(context, start=1))

    async def _rag(self, context: List[Dict], query: str, session: AsyncSession, country: str):
        """Stream an answer to *query*, yielding the accumulated response text.

        With non-empty *context*, the rendered context is appended to the history
        as an assistant message and ``settings.PROMPT`` is the system prompt;
        otherwise ``settings.EMPTY_PROMPT`` is used. Each yield is the whole
        response so far, not just the newest delta. After the stream ends, the
        answer is appended to the history and a ``MessagePair`` is added to
        *session* (not committed).
        """
        if context:
            context_str = await self.create_context_str(context)
            self.chat_history.append({"role": 'assistant', "content": context_str})
            system_prompt = settings.PROMPT
        else:
            system_prompt = settings.EMPTY_PROMPT
        self.chat_history.append({"role": 'user', "content": query})
        messages = [{'role': 'system', 'content': system_prompt}, *self.chat_history]

        stream = await settings.OPENAI_CLIENT.chat.completions.create(
            messages=messages,
            temperature=0.1,
            n=1,
            model="gpt-3.5-turbo",
            stream=True,
        )
        response = ''
        async for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta is not None:
                response += delta
                yield response
                # Short pause between yields, presumably to pace client updates.
                await asyncio.sleep(0.02)

        self.chat_history.append({"role": 'assistant', "content": response})
        try:
            session.add(MessagePair(user_message=query, bot_response=response, country=country))
        except Exception:
            self._logger.exception("Could not stage MessagePair for saving")

    async def ask_and_send(self, data: Dict, websocket: WebSocket, session: AsyncSession):
        """Answer ``data['query']`` over *websocket*, then send the enriched answer.

        *data* must provide ``query`` and ``country``. Partial answers are sent
        as text frames, followed by a single "ENRICHED:" frame. If streaming or
        enrichment fails, the error is logged, the session is committed and
        closed, and the exception is not re-raised.
        """
        query = data['query']
        country = data['country']
        query_embedding = await self._convert_to_embeddings(query)
        context = await self._get_context_data(query_embedding)
        try:
            async for partial in self._rag(context, query, session, country):
                await websocket.send_text(partial)
            await websocket.send_text(await self.analyze_full_response())
        except Exception:
            self._logger.exception("Failed to answer query; saving pending messages")
            await self.emergency_db_saving(session)

    @staticmethod
    async def emergency_db_saving(session: AsyncSession):
        """Commit pending changes on *session* and always close it afterwards."""
        try:
            await session.commit()
        finally:
            await session.close()