File size: 3,697 Bytes
8a2e2aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66ae10b
8a2e2aa
66ae10b
8a2e2aa
 
 
b3f461b
 
8a2e2aa
b3f461b
c56334c
 
 
 
b3f461b
 
c56334c
8a2e2aa
 
 
57e67e8
b3f461b
 
 
0a11e3d
b3f461b
8a2e2aa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import pandas as pd
import os
import numpy as np
from openai import OpenAI
from data import data as df

class HadithSearch:
    """Semantic search over a table of hadiths using OpenAI embeddings.

    A user's free-text situation is rewritten by a chat model into a short
    list of topics, that list is embedded, and the rows of ``self.data`` are
    ranked by cosine similarity against their stored embeddings.
    """

    # Name of the per-row embedding column; the spelling matches the table.
    _EMBEDDING_COLUMN = "embeding"

    # System prompt asking the chat model to turn a situation into topics.
    _TOPIC_PROMPT = "Your task is to transform a described situation into a list of the top 3 most important things to look for in a database of Islamic hadith that could be helpful to bring answers. \n\nIt should be very specific and formatted with only the list and remove all occurences of the word 'Hadiths', just the topics sought. JSON FORMAT!\n\nThe goal is to use this list to perform cosine similarity embedding search on the hadith database."

    def __init__(self, api_key, data=None):
        """Create the OpenAI client and attach the hadith table.

        Args:
            api_key: OpenAI API key passed to ``OpenAI(api_key=...)``.
            data: Optional DataFrame to search. When omitted, the table
                imported at module level as ``df`` is used (shared by all
                instances that do not pass their own).
        """
        self.client = OpenAI(api_key=api_key)
        self.data = df if data is None else data

    def _cosine_similarity(self, a, b):
        """Return the cosine similarity of vectors *a* and *b*.

        A zero vector makes the denominator 0, giving ``nan``.
        """
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    def _get_embedding(self, text, model="text-embedding-ada-002"):
        """Turn a free-text situation into a query embedding.

        ``gpt-3.5-turbo`` first rewrites *text* into a short list of topics
        to look for; that list is then embedded with *model*.

        Args:
            text: The user's description. When it is a string, newlines are
                replaced by spaces; other values are passed through unchanged.
            model: OpenAI embedding model name.

        Returns:
            The embedding vector (``.data[0].embedding`` of the response).
        """
        if isinstance(text, str):
            text = text.replace("\n", " ")
        completion = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": self._TOPIC_PROMPT},
                {"role": "user", "content": text},
            ],
            temperature=1,
            max_tokens=684,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        # If the model returns no content, embed the user's own text rather
        # than the literal string "None".
        topics = completion.choices[0].message.content or text
        return self.client.embeddings.create(input=f"{topics}", model=model).data[0].embedding

    def search_hadiths(self, user_input, num_hadiths=10):
        """Return the *num_hadiths* hadiths most relevant to *user_input*.

        Args:
            user_input: Free-text description of the user's situation.
            num_hadiths: How many results to return; converted with ``int``
                and clamped at 0.

        Returns:
            A Markdown-style string with one section per matching hadith.

        Raises:
            ValueError: If no data is loaded or it has no embedding column.
        """
        if self.data is None:
            raise ValueError("Data not loaded.")

        self._unwrap_embeddings()
        user_embedding = self._get_embedding(user_input, model='text-embedding-ada-002')
        results = self._rank(user_embedding, num_hadiths)

        print(f"Number of hadiths to display: {num_hadiths}")
        print(f"Shape of df: {str(results.shape)}")
        return self._format_results(results.to_dict(orient="records"))

    def _unwrap_embeddings(self):
        """Replace ``{"embeding": vector}`` cells with the bare vector, in place.

        Cells that already hold a bare vector are kept as they are, so once
        the column is unwrapped later calls do not rewrite it.

        Raises:
            ValueError: If ``self.data`` has no embedding column.
        """
        column_name = self._EMBEDDING_COLUMN
        if column_name not in self.data.columns:
            raise ValueError(f"Data has no '{column_name}' column.")
        column = self.data[column_name]
        if column.map(lambda cell: isinstance(cell, dict)).any():
            self.data[column_name] = column.map(
                lambda cell: cell["embeding"] if isinstance(cell, dict) else cell
            )

    def _rank(self, user_embedding, num_hadiths):
        """Return the rows of ``self.data`` most similar to *user_embedding*.

        Cosine similarities for all rows are computed in one matrix-vector
        product and attached to a new DataFrame; ``self.data`` is not
        modified. The result is sorted by descending ``similarities``,
        truncated to ``num_hadiths`` rows (negative counts give no rows), and
        stripped of the ``id``, ``hadith_id`` and embedding columns when
        present.
        """
        count = max(int(num_hadiths), 0)
        vectors = self.data[self._EMBEDDING_COLUMN].to_numpy()
        if len(vectors) == 0:
            similarities = np.empty(0)
        else:
            matrix = np.vstack(vectors).astype(float)
            query = np.asarray(user_embedding, dtype=float)
            # Zero-norm vectors yield nan (as in _cosine_similarity); the
            # errstate only silences numpy's divide-by-zero warning.
            with np.errstate(divide="ignore", invalid="ignore"):
                similarities = (matrix @ query) / (
                    np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
                )
        ranked = (
            self.data.assign(similarities=similarities)
            .sort_values("similarities", ascending=False)
            .head(count)
        )
        return ranked.drop(
            columns=["id", "hadith_id", self._EMBEDDING_COLUMN], errors="ignore"
        )

    def _format_results(self, results):
        """Render ranked hadith records as a Markdown-style string.

        Args:
            results: Records (dicts) with ``source``, ``chapter``,
                ``chapter_no``, ``similarities``, ``chain_indx``, ``text_en``
                and ``text_ar`` keys; ``hadith_no`` is used when present.

        Returns:
            One section per record, each followed by a horizontal rule.
            Backticks are removed from the whole output.
        """
        separator = "\n\n-----------------------------------------------------------------------------------------------------\n\n"
        sections = []
        for result in results:
            # NOTE(review): assumes the hadith-number column is "hadith_no";
            # records without it show "N/A" instead of raising.
            hadith_no = result.get("hadith_no", "N/A")
            # Strip the indentation run embedded in the English text.
            english = str(result["text_en"]).replace("                          ", "")
            similarity = round(result["similarities"] * 100, 2)
            sections.append(
                f"### Source: {result['source']} | Chapter name : {result['chapter']}"
                f" | Chapter number: {result['chapter_no']} | Hadith number : {hadith_no}\n\n"
                f"Similarity with query: {similarity}% | Chain index: {result['chain_indx']}\n\n"
                f"### Hadith content:\n\n{english}\n\n"
                f"Arabic version: \n\n{result['text_ar']}"
                + separator
            )
        return "".join(sections).replace("`", "")