File size: 2,947 Bytes
7d6d701
7d2deb5
7d6d701
99bbf81
4fb4308
580858f
e02bd6d
a627434
7d6d701
 
fd30064
 
d693fc5
 
08b6d98
 
 
 
 
 
 
 
86d2f65
ebcdcac
044c0a3
86d2f65
044c0a3
ebcdcac
044c0a3
1e517cc
acf522c
26b6a5b
ddfaa69
 
 
12d440a
1e517cc
1283168
9102fcd
99bbf81
d693fc5
53d588f
bf1b617
99bbf81
93003ed
ddfaa69
d693fc5
db5f00f
 
99bbf81
93003ed
ddfaa69
1283168
7eac7c9
99bbf81
ddfaa69
 
 
1a8b52b
ddfaa69
1283168
12d440a
4fb4308
c2e6078
37ab520
043b829
99bbf81
08b6d98
 
4fb4308
 
 
 
 
 
 
 
 
8d60a3f
7d6d701
 
4fb4308
7d6d701
bb79bf1
 
99bbf81
 
b7d5b27
908ded3
fd30064
4fb4308
a4da0c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import os, time

from dotenv import load_dotenv, find_dotenv

from rag import llm_chain, rag_chain
from trace import wandb_trace

# Load variables from the .env file located by find_dotenv() into the process
# environment (e.g. DESCRIPTION, used as the interface description below).
_ = load_dotenv(find_dotenv())

# Choices offered by the "Retrieval Augmented Generation" radio button;
# invoke() compares the selected value against these labels.
RAG_OFF = "Off"
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"

# Run settings that invoke() forwards to wandb_trace() with every call.
# NOTE(review): nothing in this file passes these values to the chains
# themselves; confirm the rag module uses the same model/chunking settings.
config = dict(
    chunk_overlap=150,
    chunk_size=1500,
    k=3,
    model_name="gpt-4-0613",
    temperature=0,
)

def invoke(openai_api_key, rag_option, prompt):
    """Answer ``prompt`` with the chain selected by ``rag_option``.

    Gradio callback for the interface built at the bottom of this file.
    Inputs are validated first. Every call that passes validation is then
    traced with ``wandb_trace``. The trace runs from a ``finally`` clause, so
    failed calls are traced too, with the error text in ``err_msg``.

    Args:
        openai_api_key: OpenAI API key from the password textbox. It must be
            non-empty and not only whitespace.
        rag_option: Selected radio value, one of ``RAG_OFF``, ``RAG_CHROMA``
            or ``RAG_MONGODB``. Any value other than the two RAG options
            runs the plain LLM chain.
        prompt: User prompt. It must be non-empty and not only whitespace.

    Returns:
        The completion text. For the RAG chains this is
        ``completion["result"]``. For the plain LLM chain it is the text of
        the first generation, or "" if the LLM returned no generation.

    Raises:
        gr.Error: If an input is missing, or if the chain call fails. In the
            second case the original exception is chained as ``__cause__``.
    """
    if not (openai_api_key and openai_api_key.strip()):
        raise gr.Error("OpenAI API Key is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    if not (prompt and prompt.strip()):
        raise gr.Error("Prompt is required.")

    # Defaults reported to wandb_trace when the chain call fails part-way.
    chain = None
    completion = ""
    result = ""
    generation_info = ""
    llm_output = ""
    err_msg = ""

    # Taken before the try so the finally clause can always compute the duration.
    start_time_ms = round(time.time() * 1000)

    try:
        if rag_option in (RAG_CHROMA, RAG_MONGODB):
            # NOTE(review): both RAG options call rag_chain() with identical
            # arguments, so nothing here selects Chroma vs. MongoDB. Confirm
            # that rag_chain picks the intended vector store.
            completion, chain = rag_chain(openai_api_key, prompt)
            result = completion["result"]
        else:
            completion, chain = llm_chain(openai_api_key, prompt)

            # If there is no first generation (the list or its first inner
            # list is empty or None), leave result/generation_info as "".
            generations = completion.generations
            if generations and generations[0] and generations[0][0] is not None:
                first = generations[0][0]
                result = first.text
                generation_info = first.generation_info

            llm_output = completion.llm_output
    except Exception as e:
        # Keep err_msg a string, matching its "" initial value, for the trace.
        err_msg = str(e)
        # gr.Error expects a message string; chain the cause for debugging.
        raise gr.Error(str(e)) from e
    finally:
        end_time_ms = round(time.time() * 1000)

        wandb_trace(config,
                    rag_option == RAG_OFF,
                    prompt,
                    completion,
                    result,
                    generation_info,
                    llm_output,
                    chain,
                    err_msg,
                    start_time_ms,
                    end_time_ms)

    return result

# Close any Gradio apps still running in this process before launching a new one.
gr.close_all()

demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label="Retrieval Augmented Generation", value=RAG_OFF),
        gr.Textbox(label="Prompt", value="What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines=1),
    ],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Generative AI - LLM & RAG",
    # Fall back to an empty description instead of crashing at startup with
    # KeyError when DESCRIPTION is not set in the environment or .env file.
    description=os.environ.get("DESCRIPTION", ""),
)

demo.launch()