# Poke-Bowl-AI — project/bot/chatbot.py
# Author: brestok
# Last change: "Update project/bot/chatbot.py" (commit f28557b, verified)
import asyncio
import base64
import os
import tempfile
import numpy as np
from project.config import settings
import pandas as pd
class ChatBot:
    """Voice ordering assistant: speech in, retrieval-augmented chat, speech out.

    Each instance holds its own conversation in ``chat_history``: a list of
    ``{'role': ..., 'content': ...}`` dicts that is sent to the chat model on
    every turn. OpenAI clients, prompts, the FAISS index and the product
    dataset are all read from ``settings``.
    """

    def __init__(self, memory=None):
        """Start a new conversation seeded with the assistant greeting.

        Args:
            memory: Currently unused; kept for interface compatibility.
        """
        # Per-instance history. This used to be a class attribute, so every
        # ChatBot shared (and kept appending to) one global list, mixing
        # conversations and stacking a new greeting per construction.
        # NOTE(review): a caller that builds a new ChatBot per message and
        # relied on that shared list for continuity must now reuse one
        # instance per conversation.
        self.chat_history: list[dict[str, str]] = [{
            "role": 'assistant',
            'content': "Hi! What would you like to order from the food?"
        }]

    @staticmethod
    def _transform_bytes_to_file(data_bytes) -> str:
        """Decode base64 audio and write it to a ``.mp3`` temporary file.

        Args:
            data_bytes: Base64-encoded audio (``str`` or ``bytes``).

        Returns:
            Path of the written file. It is created with ``delete=False``,
            so the caller is responsible for removing it.
        """
        audio_bytes = base64.b64decode(data_bytes)
        # The .mp3 suffix is kept on the filename; presumably the
        # transcription API uses it to detect the audio format.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
            temp_file.write(audio_bytes)
        return temp_file.name

    @staticmethod
    async def _transcript_audio(temp_filepath: str) -> str:
        """Transcribe an audio file with ``whisper-1``, requesting Dutch (``nl``).

        Args:
            temp_filepath: Path to the audio file to upload.

        Returns:
            The transcript text.
        """
        with open(temp_filepath, 'rb') as file:
            transcript = await settings.OPENAI_CLIENT.audio.transcriptions.create(
                model='whisper-1',
                file=file,
                language="nl"
            )
        return transcript.text

    @staticmethod
    async def _convert_to_embeddings(query: str) -> list[float]:
        """Embed ``query`` with ``text-embedding-3-large`` and return its vector."""
        response = await settings.OPENAI_CLIENT.embeddings.create(
            input=query,
            model='text-embedding-3-large'
        )
        # A single input string yields exactly one embedding.
        return response.data[0].embedding

    @staticmethod
    async def _convert_response_to_voice(ai_response: str) -> str:
        """Synthesize ``ai_response`` with ``tts-1`` (voice ``alloy``).

        Returns:
            The generated audio bytes, base64-encoded as a UTF-8 string.
        """
        audio = await settings.OPENAI_CLIENT.audio.speech.create(
            model="tts-1",
            voice="alloy",
            input=ai_response
        )
        return base64.b64encode(audio.content).decode('utf-8')

    @staticmethod
    async def _get_context_data(query: list[float]) -> str:
        """Return the text of the nearest product within ``settings.SEARCH_RADIUS``.

        Args:
            query: Embedding vector of the user's request.

        Returns:
            The ``"Search"`` column of the match with the smallest distance,
            followed by a blank line, or ``''`` when the range search finds
            nothing. "Smallest distance" assumes an L2-style index where lower
            means closer — NOTE(review): confirm against the index metric.
        """
        query_matrix = np.array([query], dtype='float32')
        # range_search is unpacked as (lims, distances, indices); with a single
        # query vector the lims offsets are not needed.
        _, distances, indices = settings.FAISS_INDEX.range_search(query_matrix, settings.SEARCH_RADIUS)
        # indices are row positions into the dataset (hence .iloc). .copy() so
        # the distance column is added to a new frame, not a view of the shared
        # dataset (avoids SettingWithCopyWarning and accidental mutation).
        matches = settings.products_dataset.iloc[indices].copy()
        matches['distance'] = distances
        best = matches.sort_values(by='distance').head(1)
        return ''.join(f'{search_text}\n\n' for search_text in best['Search'])

    async def _rag(self, query: str, query_type: str, context: str = None):
        """Record the user turn, ask the chat model, and record its reply.

        Prompt selection:
          * non-empty ``context`` -> ``PRODUCT_PROMPT`` (the context is first
            appended to the history as an assistant message);
          * otherwise by case-insensitive substring of ``query_type``:
            ``'search'`` -> ``EMPTY_PRODUCT_PROMPT``, ``'purchase'`` ->
            ``ADD_TO_CART_PROMPT``, ``'product_list'`` -> ``PRODUCT_LIST_PROMPT``,
            anything else -> ``EMPTY_PRODUCT_PROMPT``.

        Args:
            query: The user's transcribed request.
            query_type: Raw label returned by ``_get_query_type`` (may be None).
            context: Retrieved product text, if any.

        Returns:
            The model's reply text (also appended to ``chat_history``).
        """
        if context:
            self.chat_history.append({'role': 'assistant', 'content': context})
            prompt = settings.PRODUCT_PROMPT
        else:
            # Guard against a None reply from the classifier.
            kind = (query_type or '').lower()
            if 'search' in kind:
                prompt = settings.EMPTY_PRODUCT_PROMPT
            elif 'purchase' in kind:
                prompt = settings.ADD_TO_CART_PROMPT
            elif 'product_list' in kind:
                prompt = settings.PRODUCT_LIST_PROMPT
            else:
                prompt = settings.EMPTY_PRODUCT_PROMPT
        self.chat_history.append({
            'role': 'user',
            'content': query
        })
        messages = [{'role': 'system', 'content': f"{prompt}"}, *self.chat_history]
        completion = await settings.OPENAI_CLIENT.chat.completions.create(
            messages=messages,
            temperature=0,
            n=1,
            model="gpt-3.5-turbo",
        )
        response = completion.choices[0].message.content
        self.chat_history.append({'role': 'assistant', 'content': response})
        return response

    async def _get_query_type(self, query: str) -> str:
        """Classify the user's reply using ``settings.ANALYZER_PROMPT``.

        The model sees the latest history message (labelled as the assistant
        message) together with the new user transcript.

        Returns:
            The model's raw text label; callers match it case-insensitively
            by substring.
        """
        assistant_message = self.chat_history[-1]['content']
        messages = [
            {
                "role": 'system',
                'content': settings.ANALYZER_PROMPT
            },
            {
                "role": 'user',
                "content": f"Assistant message: {assistant_message}\n"
                           f"User response: {query}"
            }
        ]
        completion = await settings.OPENAI_CLIENT.chat.completions.create(
            messages=messages,
            temperature=0,
            n=1,
            model="gpt-3.5-turbo",
        )
        return completion.choices[0].message.content

    async def ask(self, data: dict):
        """Handle one voice turn: transcribe, classify, answer, synthesize.

        Args:
            data: Request payload; ``data['audio']`` holds base64-encoded audio.

        Returns:
            dict with ``user_query`` (transcript), ``ai_response`` (reply text)
            and ``voice_response`` (base64-encoded speech of the reply).
        """
        temp_filepath = self._transform_bytes_to_file(data['audio'])
        try:
            transcript = await self._transcript_audio(temp_filepath)
        finally:
            # The file is only needed for transcription; remove it even when
            # that (or anything later) fails so temp files do not leak.
            try:
                os.remove(temp_filepath)
            except FileNotFoundError:
                pass
        query_type = await self._get_query_type(transcript)
        context = None
        # Same case-insensitive substring rule as _rag, so "Search"/"search."
        # from the classifier still triggers retrieval.
        if 'search' in (query_type or '').lower():
            transformed_query = await self._convert_to_embeddings(transcript)
            context = await self._get_context_data(transformed_query)
        ai_response = await self._rag(transcript, query_type, context)
        voice_ai_response = await self._convert_response_to_voice(ai_response)
        return {
            'user_query': transcript,
            'ai_response': ai_response,
            'voice_response': voice_ai_response
        }