Spaces:
Sleeping
Sleeping
import pandas as pd | |
import os | |
import numpy as np | |
from openai import OpenAI | |
from data import data as df | |
class HadithSearch:
    """Semantic search over a hadith table using OpenAI chat + embedding APIs.

    A user's free-text situation is first rewritten by a chat model into a
    short list of topics. That list is embedded, and every row of
    ``self.data`` is ranked by cosine similarity against the resulting vector.
    """

    # System prompt that turns a described situation into search topics.
    _QUERY_PROMPT = "Your task is to transform a described situation into a list of the top 3 most important things to look for in a database of Islamic hadith that could be helpful to bring answers. \n\nIt should be very specific and formatted with only the list and remove all occurences of the word 'Hadiths', just the topics sought. JSON FORMAT!\n\nThe goal is to use this list to perform cosine similarity embedding search on the hadith database."

    # Visual separator appended after each formatted hadith.
    _SEPARATOR = "\n\n-----------------------------------------------------------------------------------------------------\n\n"

    def __init__(self, api_key):
        """Create the OpenAI client and bind the shared hadith table.

        Args:
            api_key: OpenAI API key handed to ``OpenAI``.
        """
        self.client = OpenAI(api_key=api_key)
        # ``df`` comes from the project's ``data`` module; presumably a pandas
        # DataFrame with an ``embeding`` column -- confirm against data.py.
        self.data = df

    def _cosine_similarity(self, a, b):
        """Return the cosine similarity of vectors *a* and *b*.

        Returns 0.0 when either vector has zero norm instead of dividing by
        zero (which would yield NaN and a runtime warning).
        """
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0:
            return 0.0
        return np.dot(a, b) / denominator

    def _get_embedding(self, text, model="text-embedding-ada-002",
                       chat_model="gpt-3.5-turbo"):
        """Rewrite *text* into search topics with a chat model, then embed it.

        Args:
            text: The user's described situation.
            model: Embedding model name.
            chat_model: Chat model used to rewrite *text* into topics.

        Returns:
            The embedding vector (list of floats) of the rewritten query.
        """
        if isinstance(text, str):
            # Newlines are flattened before the text is sent to the chat model.
            text = text.replace("\n", " ")
        rewritten = self.client.chat.completions.create(
            model=chat_model,
            messages=[
                {"role": "system", "content": self._QUERY_PROMPT},
                {"role": "user", "content": text},
            ],
            temperature=1,
            max_tokens=684,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        ).choices[0].message.content
        # An empty/None completion would otherwise be embedded as "None";
        # fall back to the user's own text in that case.
        query = rewritten if rewritten else text
        return self.client.embeddings.create(
            input=f"{query}", model=model
        ).data[0].embedding

    def _unwrap_embeddings(self, column_name):
        """Replace ``{column_name: vector}`` cells with the bare vector, in place.

        Cells that are already raw vectors (lists/arrays) make the lookup fail
        with TypeError/IndexError; a dict lacking the key raises KeyError. In
        all those cases the column is left as-is. The conversion is idempotent,
        so running it on every search is harmless.
        """
        column = self.data[column_name]
        try:
            self.data[column_name] = column.apply(lambda cell: cell[column_name])
        except (TypeError, KeyError, IndexError):
            pass

    def search_hadiths(self, user_input, num_hadiths=10):
        """Return the *num_hadiths* rows most similar to *user_input*.

        Args:
            user_input: Free-text description of the user's situation.
            num_hadiths: Number of results to keep (converted with ``int``).

        Returns:
            A Markdown-style string produced by ``_format_results``.

        Raises:
            ValueError: If no data table is loaded.
            KeyError: If the table has no ``embeding`` column.
        """
        if self.data is None:
            raise ValueError("Data not loaded.")
        embedding_column_name = "embeding"
        self._unwrap_embeddings(embedding_column_name)

        user_embedding = self._get_embedding(user_input, model='text-embedding-ada-002')
        similarities = self.data[embedding_column_name].apply(
            lambda vector: self._cosine_similarity(vector, user_embedding)
        )
        # Score a copy so the shared table is not mutated per request.
        results = (
            self.data.assign(similarities=similarities)
            .sort_values('similarities', ascending=False)
            .head(int(num_hadiths))
        )
        # Drop whichever bookkeeping columns exist; missing ones are ignored.
        results = results.drop(
            columns=["id", "hadith_id", embedding_column_name], errors="ignore"
        )
        print(f"Number of hadiths to display: {num_hadiths}")
        print(f"Shape of df: {str(results.shape)}")
        return self._format_results(results.to_dict(orient="records"))

    def _format_results(self, results):
        """Render result records as one Markdown-style string.

        Args:
            results: Iterable of dicts with keys ``source``, ``chapter``,
                ``chapter_no``, ``similarities``, ``chain_indx``, ``text_en``,
                ``text_ar`` and (optionally) ``hadith_no``.

        Returns:
            The concatenated text, with all backticks removed.
        """
        sections = []
        for record in results:
            # Collapse runs of whitespace instead of deleting every space,
            # which made the English text unreadable.
            english = " ".join(str(record["text_en"]).split())
            # NOTE(review): assumes the table has a ``hadith_no`` column; the
            # old code printed ``chapter_no`` here by mistake.
            hadith_number = record.get("hadith_no", "N/A")
            similarity = round(record["similarities"] * 100, 2)
            sections.append(
                f"### Source: {record['source']}"
                f" | Chapter name : {record['chapter']}"
                f" | Chapter number: {record['chapter_no']}"
                f" | Hadith number : {hadith_number}\n\n"
                f"Similarity with query: {similarity}%"
                f" | Chain index: {record['chain_indx']}\n\n"
                f"### Hadith content:\n\n{english}\n\n"
                f"Arabic version: \n\n{record['text_ar']}"
                f"{self._SEPARATOR}"
            )
        # Backticks would break the Markdown rendering; strip them once.
        return "".join(sections).replace("`", "")