File size: 4,421 Bytes
79c0556
 
 
 
 
 
 
ba08d17
79c0556
 
 
6719fe3
 
79c0556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba08d17
79c0556
 
 
 
ba08d17
79c0556
 
ba08d17
79c0556
ba08d17
79c0556
 
 
ba08d17
79c0556
 
 
ba08d17
79c0556
 
ba08d17
 
79c0556
ba08d17
79c0556
ba08d17
 
 
 
79c0556
ba08d17
79c0556
 
ba08d17
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import functools
import os
import pickle
import time
from io import StringIO

import faiss
import gradio as gr
import numpy as np
import openai
import pandas as pd
from langchain.embeddings.openai import OpenAIEmbeddings

# Configure the OpenAI client from the environment; the key is None when the
# OPENAI_API_KEY variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")

def create_and_save_faiss_index(questions, embedding_model, index_file, embedding_file):
    """Embed *questions*, build an exact L2 FAISS index, and persist both.

    Args:
        questions: Question strings to index. Row ``i`` of the index
            corresponds to ``questions[i]``.
        embedding_model: Object exposing ``embed_documents(list_of_str)``
            (here a LangChain ``OpenAIEmbeddings``).
        index_file: Path the FAISS index is written to.
        embedding_file: Path the raw embeddings are pickled to.

    Returns:
        ``(faiss_index, question_embeddings)``, where ``question_embeddings``
        is exactly what ``embed_documents`` returned.

    Raises:
        ValueError: If *questions* is empty, since the vector dimension
            cannot be inferred from zero embeddings.
    """
    questions = list(questions)
    if not questions:
        raise ValueError("Cannot build a FAISS index from an empty question list.")

    question_embeddings = embedding_model.embed_documents(questions)
    # FAISS stores float32 vectors. Convert explicitly instead of handing it
    # numpy's default float64 array.
    vectors = np.array(question_embeddings, dtype=np.float32)
    faiss_index = faiss.IndexFlatL2(vectors.shape[1])
    faiss_index.add(vectors)

    faiss.write_index(faiss_index, index_file)
    with open(embedding_file, 'wb') as f:
        pickle.dump(question_embeddings, f)

    return faiss_index, question_embeddings

def load_faiss_index(index_file, embedding_file):
    """Read back a FAISS index and its pickled embeddings from disk.

    This is the counterpart of ``create_and_save_faiss_index``. It returns the
    same ``(faiss_index, question_embeddings)`` pair that function wrote.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file. Only point ``embedding_file`` at files this application wrote itself.
    """
    index = faiss.read_index(index_file)
    with open(embedding_file, 'rb') as handle:
        stored_embeddings = pickle.load(handle)
    return (index, stored_embeddings)

def retrieve_answer(question, faiss_index, embedding_model, answers, threshold=0.8):
    """Return the stored answer whose question is nearest to *question*.

    Args:
        question: User question to look up.
        faiss_index: FAISS index whose row ``i`` corresponds to ``answers[i]``.
        embedding_model: Object exposing ``embed_query(str)``.
        answers: Answers aligned with the rows of *faiss_index*.
        threshold: Largest accepted distance as reported by ``search``. For an
            ``IndexFlatL2`` this is the *squared* L2 distance.

    Returns:
        ``answers[i]`` for the closest row. Otherwise, the fixed
        "No good match found in dataset..." message is returned when the
        index reports no neighbour or the nearest one is farther than
        *threshold*. Callers compare against that exact string.
    """
    no_match = "No good match found in dataset. Using GPT-4o-mini to generate an answer."

    question_embedding = embedding_model.embed_query(question)
    # FAISS expects a (n_queries, dim) float32 matrix.
    query = np.array([question_embedding], dtype=np.float32)
    distances, indices = faiss_index.search(query, k=1)

    closest_distance = distances[0][0]
    closest_index = int(indices[0][0])
    print(f"closest_distance: {closest_distance}")

    # FAISS reports label -1 when it has no neighbour to return. The padded
    # distance depends on the metric, so check the label explicitly; otherwise
    # answers[-1] would silently return the *last* answer.
    if closest_index < 0 or closest_distance > threshold:
        return no_match
    return answers[closest_index]

def ask_openai_gpt4(question, *, model="gpt-4o-mini", max_tokens=150, temperature=None, top_p=None):
    """Ask an OpenAI chat model to answer *question* directly.

    Args:
        question: The user's question. It is wrapped in a fixed
            "Answer the following medical question: ..." prompt.
        model: Chat model name. Defaults to ``gpt-4o-mini``, as before.
        max_tokens: Completion token limit. Defaults to 150, as before.
        temperature: Sampling temperature. It is left out of the request when
            None, so the API default applies.
        top_p: Nucleus-sampling mass. It is left out of the request when None.

    Returns:
        The text of the first choice. Returns ``""`` when the API sends no text
        content, because the response's ``message.content`` field is nullable.
    """
    # Forward only the sampling parameters the caller actually set. With the
    # defaults, the request is identical to the original one.
    sampling = {
        name: value
        for name, value in (("temperature", temperature), ("top_p", top_p))
        if value is not None
    }
    response = openai.chat.completions.create(
        messages=[
            {"role": "user", "content": f"Answer the following medical question: {question}"}
        ],
        model=model,
        max_tokens=max_tokens,
        **sampling,
    )
    content = response.choices[0].message.content
    return content if content is not None else ""

@functools.lru_cache(maxsize=1)
def _load_knowledge_base(csv_path="medquad.csv", index_file="faiss.index", embedding_file="embeddings.pkl"):
    """Return ``(faiss_index, answers, embedding_model)``, built at most once per process.

    The answers are always read from *csv_path* so that row ``i`` of the index
    maps to ``answers[i]``. An on-disk index is reused only if its size matches
    the CSV; otherwise it is rebuilt.
    """
    df = pd.read_csv(csv_path)
    answers = df['answer'].tolist()
    embedding_model = OpenAIEmbeddings(openai_api_key=openai.api_key)

    faiss_index = None
    if os.path.exists(index_file) and os.path.exists(embedding_file):
        print("Loading FAISS index from disk...")
        faiss_index, _ = load_faiss_index(index_file, embedding_file)
        # An index built from a different CSV would map rows to the wrong answers.
        if faiss_index.ntotal != len(answers):
            print("FAISS index size does not match the dataset; rebuilding...")
            faiss_index = None

    if faiss_index is None:
        print("Creating and saving FAISS index...")
        questions = df['question'].tolist()
        faiss_index, _ = create_and_save_faiss_index(questions, embedding_model, index_file, embedding_file)

    return faiss_index, answers, embedding_model


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio ChatInterface handler: answer from the dataset, else ask the LLM.

    Looks up the nearest dataset question through FAISS. If nothing is within
    the threshold, it falls back to ``ask_openai_gpt4``. Yields one chat
    message: the answer followed by the elapsed response time.

    NOTE(review): ``history``, ``system_message``, ``max_tokens``,
    ``temperature`` and ``top_p`` are supplied by the ChatInterface inputs but
    are not used by the retrieval or fallback path yet.
    """
    no_match = "No good match found in dataset. Using GPT-4o-mini to generate an answer."
    start_time = time.perf_counter()

    # Cached: the CSV, index and embedding client are set up on the first message only.
    faiss_index, answers, embedding_model = _load_knowledge_base()
    response_text = retrieve_answer(message, faiss_index, embedding_model, answers=answers, threshold=0.8)

    if response_text == no_match:
        print(no_match)
        response_text = ask_openai_gpt4(message)

    response_time = time.perf_counter() - start_time

    # ChatInterface renders each yielded value as the bot message. The previous
    # 3-tuple is not a text message: in Gradio's tuples chat format, a tuple is
    # read as a (file path, alt text) pair. Fold the timing into the text instead.
    yield f"{response_text}\n\nResponse time: {response_time:.4f} seconds"


# Gradio ChatInterface that serves `respond`, with extra inputs for model settings.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        # The OpenAI chat API accepts temperature values in [0, 2].
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    ],
    title="Medical Chatbot with Customizable Parameters and Response Time",
    # The fallback model is gpt-4o-mini (see ask_openai_gpt4), not GPT-4.
    description="A chatbot with customizable parameters using FAISS for quick responses or fallback to GPT-4o-mini when no relevant answer is found. Response time is also tracked."
)

if __name__ == "__main__":
    demo.launch()