Spaces:
Build error
Build error
File size: 2,947 Bytes
7d6d701 7d2deb5 7d6d701 99bbf81 4fb4308 580858f e02bd6d a627434 7d6d701 fd30064 d693fc5 08b6d98 86d2f65 ebcdcac 044c0a3 86d2f65 044c0a3 ebcdcac 044c0a3 1e517cc acf522c 26b6a5b ddfaa69 12d440a 1e517cc 1283168 9102fcd 99bbf81 d693fc5 53d588f bf1b617 99bbf81 93003ed ddfaa69 d693fc5 db5f00f 99bbf81 93003ed ddfaa69 1283168 7eac7c9 99bbf81 ddfaa69 1a8b52b ddfaa69 1283168 12d440a 4fb4308 c2e6078 37ab520 043b829 99bbf81 08b6d98 4fb4308 8d60a3f 7d6d701 4fb4308 7d6d701 bb79bf1 99bbf81 b7d5b27 908ded3 fd30064 4fb4308 a4da0c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import gradio as gr
import os, time
from dotenv import load_dotenv, find_dotenv
from rag import llm_chain, rag_chain
from trace import wandb_trace
# Load environment variables (e.g. DESCRIPTION used below) from a .env file if present.
_ = load_dotenv(find_dotenv())
# UI labels for the Retrieval Augmented Generation mode selector.
RAG_OFF = "Off"
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"
# Run configuration passed to wandb_trace() with every invocation
# (chunk_* and k presumably mirror the offline ingestion settings — confirm against rag.py).
config = {
"chunk_overlap": 150,
"chunk_size": 1500,
"k": 3,
"model_name": "gpt-4-0613",
"temperature": 0,
}
def invoke(openai_api_key, rag_option, prompt):
    """Run *prompt* through the selected chain and trace the call to W&B.

    Args:
        openai_api_key: OpenAI API key entered by the user.
        rag_option: One of RAG_OFF, RAG_CHROMA, RAG_MONGODB.
        prompt: The user's prompt text.

    Returns:
        The completion text produced by the chain.

    Raises:
        gr.Error: On missing input or any failure inside the chain call.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    chain = None
    completion = ""
    result = ""
    generation_info = ""
    llm_output = ""
    err_msg = ""
    try:
        start_time_ms = round(time.time() * 1000)
        if rag_option in (RAG_CHROMA, RAG_MONGODB):
            # Both vector stores use the same retrieval chain here; the one-time
            # ingestion steps (document_loading_splitting / document_storage_*)
            # are assumed to have been run offline — confirm in rag.py.
            completion, chain = rag_chain(openai_api_key, prompt)
            result = completion["result"]
        else:
            completion, chain = llm_chain(openai_api_key, prompt)
            # LLMResult-style object: pull text and metadata from the first generation.
            if completion.generations[0] is not None and completion.generations[0][0] is not None:
                result = completion.generations[0][0].text
                generation_info = completion.generations[0][0].generation_info
            llm_output = completion.llm_output
    except Exception as e:
        # Keep a plain string for the trace payload rather than the exception object.
        err_msg = str(e)
        raise gr.Error(str(e))
    finally:
        # Trace every call — success or failure — with timing information.
        end_time_ms = round(time.time() * 1000)
        wandb_trace(config,
                    rag_option == RAG_OFF,
                    prompt,
                    completion,
                    result,
                    generation_info,
                    llm_output,
                    chain,
                    err_msg,
                    start_time_ms,
                    end_time_ms)
    return result
gr.close_all()

# Single-page UI: API key, RAG mode selector, prompt in; completion text out.
demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB],
                 label="Retrieval Augmented Generation",
                 value=RAG_OFF),
        gr.Textbox(label="Prompt",
                   value="What are GPT-4's media capabilities in 5 emojis and 1 sentence?",
                   lines=1),
    ],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Generative AI - LLM & RAG",
    # NOTE(review): raises KeyError if DESCRIPTION is unset — confirm the .env / Space secrets provide it.
    description=os.environ["DESCRIPTION"],
)

demo.launch()