# Provenance (copied from the hosting page's file view): app.py, last change
# "Update app.py" by chukbert, commit 6719fe3 (verified), file size 4.42 kB.
import pandas as pd
import openai
import faiss
import numpy as np
import time
import os
import pickle
import gradio as gr
from langchain.embeddings.openai import OpenAIEmbeddings
from io import StringIO
# Read the OpenAI credential from the environment at import time; os.getenv
# returns None when OPENAI_API_KEY is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")
def create_and_save_faiss_index(questions, embedding_model, index_file, embedding_file):
    """Embed *questions*, build a flat L2 FAISS index over them, and persist both.

    Args:
        questions: Iterable of question strings; row ``i`` of the index
            corresponds to the ``i``-th question.
        embedding_model: Object exposing ``embed_documents(list_of_str)`` that
            returns one equal-length vector per input string.
        index_file: Path the FAISS index is written to.
        embedding_file: Path the embeddings are pickled to.

    Returns:
        A ``(faiss_index, question_embeddings)`` tuple, where
        ``question_embeddings`` is exactly what ``embed_documents`` returned.

    Raises:
        ValueError: If *questions* is empty (there is no vector to take the
            index dimensionality from).
    """
    questions = list(questions)
    if not questions:
        raise ValueError("Cannot build a FAISS index from an empty list of questions.")

    question_embeddings = embedding_model.embed_documents(questions)

    # FAISS works on contiguous float32 matrices; np.array() over Python floats
    # would produce float64, so convert explicitly instead of relying on the
    # binding to coerce (or reject) the input.
    embedding_matrix = np.ascontiguousarray(question_embeddings, dtype=np.float32)

    faiss_index = faiss.IndexFlatL2(embedding_matrix.shape[1])
    faiss_index.add(embedding_matrix)
    faiss.write_index(faiss_index, index_file)

    # The raw embeddings are cached next to the index, unchanged.
    with open(embedding_file, 'wb') as f:
        pickle.dump(question_embeddings, f)

    return faiss_index, question_embeddings
def load_faiss_index(index_file, embedding_file):
    """Read a previously saved FAISS index and its pickled embeddings.

    Args:
        index_file: Path of the index to read with ``faiss.read_index``.
        embedding_file: Path of the pickle holding the embeddings.

    Returns:
        A ``(faiss_index, question_embeddings)`` tuple.
    """
    restored_index = faiss.read_index(index_file)

    # NOTE(review): pickle.load executes whatever the file encodes; this is
    # only safe while embedding_file is a cache this application wrote itself.
    with open(embedding_file, 'rb') as cache_handle:
        cached_embeddings = pickle.load(cache_handle)

    return restored_index, cached_embeddings
def retrieve_answer(question, faiss_index, embedding_model, answers, threshold=0.8):
    """Return the stored answer whose question is nearest to *question*.

    Args:
        question: The user's question text.
        faiss_index: Object exposing ``search(matrix, k=...)`` that returns
            ``(distances, indices)`` arrays, as a FAISS index does.
        embedding_model: Object exposing ``embed_query(str)`` returning a vector.
        answers: Sequence of answers, positionally aligned with the vectors
            stored in *faiss_index*.
        threshold: Largest distance still accepted as a match.

    Returns:
        ``answers[i]`` for the nearest neighbour ``i``, or the fixed
        "No good match found ..." sentinel string when the nearest neighbour is
        farther than *threshold* or its position does not exist in *answers*.
    """
    no_match_message = "No good match found in dataset. Using GPT-4o-mini to generate an answer."

    question_embedding = embedding_model.embed_query(question)
    # FAISS expects a 2-D float32 query matrix; Python floats would otherwise
    # become float64.
    query_matrix = np.asarray([question_embedding], dtype=np.float32)
    distances, indices = faiss_index.search(query_matrix, k=1)

    closest_distance = distances[0][0]
    closest_index = int(indices[0][0])
    print(f"closest_distance: {closest_distance}")

    if closest_distance > threshold:
        return no_match_message

    # Guard the lookup: FAISS reports a missing neighbour as -1 (which would
    # silently select the last answer), and an index built from a different
    # dataset can point past the end of `answers`.
    if not 0 <= closest_index < len(answers):
        return no_match_message

    return answers[closest_index]
def ask_openai_gpt4(question, *, model="gpt-4o-mini", max_tokens=150):
    """Ask an OpenAI chat model to answer a medical question.

    Args:
        question: The user's question; it is embedded in a fixed
            "Answer the following medical question: ..." prompt.
        model: Chat model name. Defaults to ``"gpt-4o-mini"``.
        max_tokens: Upper bound on generated tokens. Defaults to ``150``.

    Returns:
        The text of the first completion choice, or an empty string when the
        response carries no text content.
    """
    response = openai.chat.completions.create(
        messages=[
            {"role": "user", "content": f"Answer the following medical question: {question}"}
        ],
        model=model,
        max_tokens=max_tokens,
    )
    # `content` is optional in the API response; normalise None to "" so
    # callers always receive a string.
    return response.choices[0].message.content or ""
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer *message* from the MedQuAD dataset, falling back to the chat model.

    The nearest stored question is looked up in a FAISS index (loaded from
    ``faiss.index`` / ``embeddings.pkl`` when both exist, otherwise built from
    ``medquad.csv`` and saved). If no stored question is close enough, the
    question is sent to ``ask_openai_gpt4`` instead.

    Args:
        message: The user's current question.
        history: Prior chat turns; accepted for the chat-callback signature
            but not used.
        system_message: Accepted for the callback signature; not used.
        max_tokens: Accepted for the callback signature; not used.
        temperature: Accepted for the callback signature; not used.
        top_p: Accepted for the callback signature; not used.

    Yields:
        One ``(response_text, response_time_text, log_text)`` tuple.
    """
    log_output = StringIO()
    start_time = time.time()

    # The answers are needed on BOTH paths: the FAISS index only stores
    # question vectors, so a hit has to be mapped back to a CSV row.
    # Assumes medquad.csv still has the row order the index was built from.
    df = pd.read_csv("medquad.csv")
    answers = df['answer'].tolist()
    embedding_model = OpenAIEmbeddings(openai_api_key=openai.api_key)

    if os.path.exists('faiss.index') and os.path.exists('embeddings.pkl'):
        log_output.write("Loading FAISS index from disk...\n")
        faiss_index, _ = load_faiss_index('faiss.index', 'embeddings.pkl')
    else:
        log_output.write("Creating and saving FAISS index...\n")
        questions = df['question'].tolist()
        faiss_index, _ = create_and_save_faiss_index(
            questions, embedding_model, 'faiss.index', 'embeddings.pkl'
        )

    response_text = retrieve_answer(
        message, faiss_index, embedding_model, answers=answers, threshold=0.8
    )

    # retrieve_answer signals "no match" with this exact sentinel string.
    if response_text == "No good match found in dataset. Using GPT-4o-mini to generate an answer.":
        log_output.write("No good match found in dataset. Using GPT-4o-mini to generate an answer.\n")
        response_text = ask_openai_gpt4(message)

    # Stop the timer and calculate response time (seconds).
    response_time = time.time() - start_time

    # NOTE(review): a 3-tuple is yielded, but the ChatInterface below declares
    # no additional outputs -- confirm the installed Gradio version accepts
    # this shape, or yield a single string instead.
    yield response_text, f"Response time: {response_time:.4f} seconds", log_output.getvalue()
# Extra controls for the chat UI. They are handed to gr.ChatInterface as
# `additional_inputs`, in the same order as the trailing parameters of
# `respond` (system_message, max_tokens, temperature, top_p).
_extra_controls = [
    gr.Textbox(label="System message", value="You are a friendly Chatbot."),
    gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
]

# Chat front-end wired to `respond`.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=_extra_controls,
    title="Medical Chatbot with Customizable Parameters and Response Time",
    description="A chatbot with customizable parameters using FAISS for quick responses or fallback to GPT-4 when no relevant answer is found. Response time is also tracked.",
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()