# openai-llm-rag / app.py
# bstraehle's picture
# Update app.py
# 08b6d98
# raw
# history blame
# 2.95 kB
import gradio as gr
import os, time
from dotenv import load_dotenv, find_dotenv
from rag import llm_chain, rag_chain
from trace import wandb_trace
# Load the .env file located by find_dotenv() at import time; the return value
# is deliberately discarded. os.environ is read further down when the UI is built.
_ = load_dotenv(find_dotenv())
# Choices offered by the "Retrieval Augmented Generation" radio button.
RAG_OFF = "Off"
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"

# Settings handed to wandb_trace() with every request.
# NOTE(review): presumably these mirror the values used inside rag.py -- confirm.
config = dict(
    chunk_overlap=150,
    chunk_size=1500,
    k=3,
    model_name="gpt-4-0613",
    temperature=0,
)
def invoke(openai_api_key, rag_option, prompt):
    """Answer ``prompt`` with the LLM, optionally through the RAG chain, and trace the call.

    Args:
        openai_api_key: OpenAI API key forwarded to ``llm_chain`` / ``rag_chain``.
        rag_option: One of ``RAG_OFF``, ``RAG_CHROMA`` or ``RAG_MONGODB``.
        prompt: The user prompt.

    Returns:
        The completion text: ``completion["result"]`` for the RAG options,
        otherwise the text of the first generation returned by ``llm_chain``.

    Raises:
        gr.Error: If an input is missing, or if the chain call fails (the
            original exception is attached as ``__cause__``).
    """
    # Truthiness checks reject both "" and None.
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")

    # Defaults so the trace in ``finally`` always has something to log,
    # even when the chain call fails early.
    chain = None
    completion = ""
    result = ""
    generation_info = ""
    llm_output = ""
    err_msg = ""

    # Taken before the ``try`` so it is guaranteed to exist in ``finally``.
    start_time_ms = round(time.time() * 1000)

    try:
        if rag_option in (RAG_CHROMA, RAG_MONGODB):
            # NOTE(review): both stores go through the same rag_chain() call and
            # rag_option is not forwarded -- confirm rag_chain selects the
            # intended vector store.
            #splits = document_loading_splitting()
            #document_storage_chroma(splits) / document_storage_mongodb(splits)
            completion, chain = rag_chain(openai_api_key, prompt)
            result = completion["result"]
        else:
            completion, chain = llm_chain(openai_api_key, prompt)

            first_batch = completion.generations[0]
            if first_batch is not None and first_batch[0] is not None:
                result = first_batch[0].text
                generation_info = first_batch[0].generation_info
                llm_output = completion.llm_output
    except Exception as e:
        # The exception object itself is handed to the trace, as before.
        err_msg = e
        # gr.Error takes a message string; chain the cause for the server log.
        raise gr.Error(str(e)) from e
    finally:
        # Runs on success and on failure so every request is traced.
        end_time_ms = round(time.time() * 1000)

        wandb_trace(config,
                    rag_option == RAG_OFF,
                    prompt,
                    completion,
                    result,
                    generation_info,
                    llm_output,
                    chain,
                    err_msg,
                    start_time_ms,
                    end_time_ms)

    return result
# Close any Gradio interfaces that are still running before creating a new one.
gr.close_all()

# Single-function UI: API key, RAG selector and prompt in; completion text out.
demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label="Retrieval Augmented Generation", value=RAG_OFF),
        gr.Textbox(label="Prompt", value="What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines=1),
    ],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Generative AI - LLM & RAG",
    # The description is cosmetic: fall back to None (the parameter's default)
    # instead of crashing with KeyError when DESCRIPTION is not set.
    description=os.environ.get("DESCRIPTION"),
)

demo.launch()