File size: 3,696 Bytes
79c0556
 
 
 
 
 
 
ba08d17
79c0556
 
08cc5dc
 
79c0556
6719fe3
9d07cf2
6719fe3
9d07cf2
045955f
e64f669
fa26b1c
 
e64f669
fa26b1c
013647d
fa26b1c
 
224e382
79c0556
 
1f1886f
79c0556
 
 
 
 
317e9ce
79c0556
 
 
045955f
 
79c0556
 
 
 
 
 
 
 
 
 
 
231f795
 
79c0556
217bda1
231f795
 
 
 
 
1f1886f
f41f752
224e382
 
231f795
224e382
231f795
 
224e382
231f795
ba08d17
231f795
 
ba08d17
231f795
8bc4ecf
231f795
 
0e29da6
076ab23
 
 
0e29da6
231f795
 
ba08d17
 
 
9b22ed0
045955f
 
 
9b22ed0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import functools
import os
import pickle
import time
from io import StringIO

import faiss
import gradio as gr
import numpy as np
import openai
import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub import login
from langchain.embeddings.openai import OpenAIEmbeddings

# Credentials are read from the environment.
openai.api_key = os.getenv("OPENAI_API_KEY")
hf_token = os.getenv("HF_TOKEN")

# Only authenticate when a token is configured: huggingface_hub.login(token=None)
# falls back to an interactive prompt, which would block a headless server.
if hf_token:
    login(token=hf_token)

@functools.lru_cache(maxsize=1)
def load_embeddings_and_faiss():
    """Download (once) and load the FAISS index and the pickled question embeddings.

    Both artifacts come from the ``chukbert/embedding-faq-medquad`` dataset repo
    on the Hugging Face Hub. The result is memoized so that repeated callers
    (``chatbot`` runs once per request) do not re-read the index and pickle from
    disk every time. A failed load is not cached, so a later call retries.

    Returns:
        tuple: ``(faiss_index, question_embeddings)`` -- the index produced by
        ``faiss.read_index`` and whatever object ``embeddings.pkl`` unpickles to.
    """
    hub_args = dict(
        repo_id="chukbert/embedding-faq-medquad",
        repo_type="dataset",
        token=hf_token,
    )
    embeddings_path = hf_hub_download(filename="embeddings.pkl", **hub_args)
    faiss_index_path = hf_hub_download(filename="faiss.index", **hub_args)

    faiss_index = faiss.read_index(faiss_index_path)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # This is only safe as long as the source repo above is trusted.
    with open(embeddings_path, 'rb') as f:
        question_embeddings = pickle.load(f)

    return faiss_index, question_embeddings

def retrieve_answer(question, faiss_index, embedding_model, answers, log_output, threshold=0.2):
    """Return the stored answer nearest to *question*, or a fallback sentinel.

    Args:
        question: User question text.
        faiss_index: Index queried via ``search(x, k=1)``. Row ``i`` is expected
            to correspond to ``answers[i]`` (NOTE(review): assumed from usage,
            confirm against the index build).
        embedding_model: Object exposing ``embed_query(text) -> vector``.
        answers: Sequence of answer strings, indexed by FAISS row id.
        log_output: Writable text stream that receives a diagnostic line.
        threshold: Largest distance still accepted as a match. This assumes a
            metric where smaller means closer (e.g. L2) -- confirm.

    Returns:
        The matched answer, or the exact string
        ``"No good match found in dataset. Using GPT-4o-mini to generate an answer."``
        when no neighbor is close enough (callers compare against it).
    """
    question_embedding = embedding_model.embed_query(question)
    # FAISS operates on float32 matrices; build the (1, dim) query explicitly
    # instead of relying on numpy's default float64.
    query = np.asarray([question_embedding], dtype="float32")
    distances, indices = faiss_index.search(query, k=1)

    closest_distance = distances[0][0]
    closest_index = indices[0][0]
    log_output.write(f"closest_distance: {closest_distance}\n")

    # FAISS reports "no neighbor found" with label -1; without this guard
    # answers[-1] would silently return the last answer in the dataset.
    if closest_index < 0 or closest_distance > threshold:
        return "No good match found in dataset. Using GPT-4o-mini to generate an answer."

    return answers[closest_index]

def ask_openai_gpt4(question):
    """Send *question* to the OpenAI chat API and return the reply text.

    Despite the function name, the model used is ``gpt-4o-mini``. The reply is
    limited to 150 tokens.
    """
    prompt = f"Answer the following medical question: {question}"
    completion = openai.chat.completions.create(
        model="gpt-4o-mini",
        max_tokens=150,
        messages=[{"role": "user", "content": prompt}],
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def _get_answers():
    """Return the dataset answers, reading ``medquad.csv`` if they are not loaded yet.

    The ``__main__`` entry point normally defines the module-level ``answers``
    list. This fallback keeps ``chatbot`` usable when the module is imported
    instead of run as a script, in which case that branch never executes.
    """
    global answers
    try:
        return answers
    except NameError:
        answers = pd.read_csv("medquad.csv")['answer'].tolist()
        return answers


def chatbot(user_input):
    """Answer *user_input* from the FAQ index, falling back to gpt-4o-mini.

    Returns:
        tuple: ``(response_text, "Response time: X.XXXX seconds", log_text)``,
        matching the three Gradio output textboxes.
    """
    log_output = StringIO()  # Collects log lines shown in the "Logs" textbox

    faiss_index, _question_embeddings = load_embeddings_and_faiss()
    embedding_model = OpenAIEmbeddings(openai_api_key=openai.api_key)
    dataset_answers = _get_answers()

    # Timed section covers retrieval plus the optional LLM fallback only.
    # perf_counter is the monotonic clock intended for measuring intervals.
    start_time = time.perf_counter()

    log_output.write("Retrieving answer from FAISS...\n")
    response_text = retrieve_answer(user_input, faiss_index, embedding_model, dataset_answers, log_output, threshold=0.3)

    # retrieve_answer signals "no match" with this exact sentinel string.
    if response_text == "No good match found in dataset. Using GPT-4o-mini to generate an answer.":
        log_output.write("No good match found in dataset. Using GPT-4o-mini to generate an answer.\n")
        response_text = ask_openai_gpt4(user_input)

    response_time = time.perf_counter() - start_time
    log_output.write(f"Response time: {response_time:.4f} seconds\n")

    return response_text, f"Response time: {response_time:.4f} seconds", log_output.getvalue()

# Gradio UI: a single text input wired to `chatbot`, with three text outputs
# (answer, response time, and the captured log).
demo = gr.Interface(
    title="Medical Chatbot with Custom Knowledge About Medical FAQ",
    description="A chatbot with custom knowledge using FAISS for quick responses or fallback to GPT-4o-mini when no relevant answer is found. Response time is also tracked.",
    fn=chatbot,
    inputs="text",
    outputs=[
        gr.Textbox(label="Chatbot Response"),
        gr.Textbox(label="Response Time"),
        gr.Textbox(label="Logs"),
    ],
)

if __name__ == "__main__":
    # Read the MedQuAD FAQ; `answers` becomes the module-level list chatbot() reads.
    df = pd.read_csv("medquad.csv")
    questions, answers = df['question'].tolist(), df['answer'].tolist()

    print(f"Loaded questions and answers. Number of questions: {len(questions)}, Number of answers: {len(answers)}")
    demo.launch()