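"""Gradio chat UI for an AWS FAQ RAG bot, with optional NeMo Guardrails.

The user picks a provider (OpenAI, Gemini, or a local Ollama endpoint), the
question is augmented with chunks retrieved from a Qdrant collection, and the
answer is generated either through the plain LangChain QA chain or through
NeMo Guardrails when the checkbox is enabled.
"""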
import os
import gradio as gr
from dotenv import load_dotenv
from nemoguardrails import LLMRails, RailsConfig
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from chain import qa_chain
from vectorstore import qdrant_client
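
# `chain` and `vectorstore` are local modules not shown here. A minimal sketch
# of what this file assumes `vectorstore.py` provides (names and storage mode
# are assumptions, not the actual implementation):
#
#     from qdrant_client import QdrantClient
#     qdrant_client = QdrantClient(":memory:")  # or a persistent/remote instance
#     # one-time ingestion, e.g.:
#     # qdrant_client.add(collection_name="aws_faq", documents=faq_texts)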

load_dotenv()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
MODEL_API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("GOOGLE_API_KEY") or ""
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL") or "http://localhost:11434/v1"

MODEL_LIST = {
    "openai": "gpt-4o-mini",
    "gemini": "gemini-1.5-pro-002",
    # Without an entry here, the "ollama" branch in predict() was unreachable
    # (KeyError and missing from the dropdown). The model name is an
    # assumption: use whatever `ollama pull` fetched locally.
    "ollama": "llama3",
}
DEFAULT_MODEL = "openai"

def vector_search(message):
    """Retrieve the four FAQ chunks most similar to `message`, joined by newlines."""
    documents = qdrant_client.query(collection_name="aws_faq", query_text=message, limit=4)
    context = '\n'.join([doc.metadata["document"] for doc in documents])
    return context
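
# Example: vector_search("How durable is S3?") returns the four closest FAQ
# chunks as one newline-joined string, spliced into the prompt as context.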

def initialize_app(llm):
    """Build an LLMRails app from the `config/` directory around the chosen LLM."""
    config = RailsConfig.from_path("config")
    return LLMRails(config=config, llm=llm)
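
# RailsConfig.from_path("config") expects a NeMo Guardrails config directory.
# A rough sketch of what `config/` might contain (contents are assumptions,
# not the project's actual rails):
#
#     config/config.yml  - rails settings (which input/output flows to run)
#     config/rails.co    - Colang flows defining allowed and blocked topics
#
# The rails app is rebuilt on every request so provider/key changes take
# effect immediately; with a fixed provider it could be built once at startup.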

def format_messages(message, relevant_chunks):
    """Package the user turn plus retrieved context for NeMo Guardrails."""
    return [
        {"role": "context", "content": {"relevant_chunks": relevant_chunks}},
        {"role": "user", "content": message},
    ]
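
# Example shape handed to app.generate_async (the "context" role is how
# NeMo Guardrails receives retrieved chunks alongside the user turn):
#
#     [{"role": "context", "content": {"relevant_chunks": "..."}},
#      {"role": "user", "content": "How reliable is S3?"}]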

async def predict(message, _, model_api_key, provider, is_guardrails):
    # Ollama runs locally and needs no key; the hosted providers do.
    if provider in ("openai", "gemini") and not model_api_key:
        return "An OpenAI or Gemini API key is required to run this demo. Please enter it in the settings and configs section!"

    if provider == "gemini":
        llm = ChatGoogleGenerativeAI(google_api_key=model_api_key, model=MODEL_LIST[provider])
    elif provider == "openai":
        llm = ChatOpenAI(openai_api_key=model_api_key, model_name=MODEL_LIST[provider])
    elif provider == "ollama":
        # Ollama's OpenAI-compatible endpoint ignores the key; a non-empty
        # placeholder is safer than "" with the OpenAI client.
        llm = ChatOpenAI(openai_api_key="ollama", openai_api_base=OLLAMA_BASE_URL, model_name=MODEL_LIST[provider])
    else:
        return "Invalid provider selected, please select a valid provider from the dropdown!"

    context = vector_search(message)

    # Without guardrails, answer through the plain LangChain QA chain.
    if not is_guardrails:
        return qa_chain(llm, message, context)

    # With guardrails, route the turn through NeMo Guardrails instead.
    app = initialize_app(llm)
    response = await app.generate_async(messages=format_messages(message, context))
    return response["content"]

with gr.Blocks() as demo:
    gr.HTML("""<div style='height: 10px'></div>""")  # small top spacer
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                """
                # AWS Chatbot | Guardrails
                Experimenting with LangChain and NeMo Guardrails.
                """
            )
        with gr.Column(scale=2):
            with gr.Group():
                with gr.Row():
                    guardrail = gr.Checkbox(label="Guardrails", info="Enables NeMo Guardrails", value=True, scale=1)
                    provider = gr.Dropdown(list(MODEL_LIST.keys()), value=DEFAULT_MODEL, show_label=False, scale=1)
                    model_key = gr.Textbox(placeholder="Enter your OpenAI/Gemini API key", type="password", value=MODEL_API_KEY, show_label=False, scale=3)

    gr.ChatInterface(
        predict,
        chatbot=gr.Chatbot(height=600, type="messages", layout="panel"),
        theme="soft",
        examples=[["How reliable is Amazon S3 with data availability?"], ["How do I get started with EC2 Capacity Blocks?"]],
        type="messages",
        additional_inputs=[model_key, provider, guardrail]
    )

if __name__ == "__main__":
    demo.launch()