import os
import time

import openai
import gradio as gr
import polars as pl
from sentence_transformers import SentenceTransformer
from langchain.vectorstores.azuresearch import AzureSearch
from dotenv import load_dotenv

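# Load credentials and endpoints from a local .env file.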
load_dotenv()

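# Azure OpenAI and Azure AI Search connection settings.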
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")
vector_store_address = os.getenv("VECTOR_STORE_URL")
vector_store_password = os.getenv("VECTOR_STORE_KEY")
index_name = "motor-gm-search"

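# Year/make/model lookup table used to populate the vehicle dropdowns.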
df = pl.read_csv("year-make-model.csv")

years = df["year"].unique().to_list()
makes = df["make"].unique().to_list()
models = df["model"].unique().to_list()

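# Prompts: one system prompt for answering, one for rewriting the chat into a search query.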
with open("sys_prompt.txt", "r") as f: |
|
sys_prompt = f.read() |
|
|
|
with open("translate_prompt.txt", "r") as f: |
|
translate_prompt = f.read() |
|
|
|
|
|
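# Query embedder; the Azure AI Search index is assumed to already hold matching
# bge-small-en document vectors plus year/make/model and title fields.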
embedder = SentenceTransformer("BAAI/bge-small-en")
vector_store = AzureSearch(
    azure_search_endpoint=vector_store_address,
    azure_search_key=vector_store_password,
    index_name=index_name,
    embedding_function=lambda x: embedder.encode([x])[0],
)


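# Dropdown callbacks: narrow the Make and Model choices to the selected vehicle.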
def filter_makes(year):
    df1 = df.filter(pl.col("year") == int(year))
    choices = sorted(df1["make"].unique().to_list())
    return gr.Dropdown.update(choices=choices, interactive=True)


def filter_models(year, make):
    df1 = df.filter(pl.col("year") == int(year))
    df1 = df1.filter(pl.col("make") == make)
    choices = sorted(df1["model"].unique().to_list())
    return gr.Dropdown.update(choices=choices, interactive=True)


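# Build an OpenAI chat payload from the Gradio history and call the Azure chat deployment.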
def gpt(history, prompt, temp=0.0, stream=True):
    hist = [{"role": "system", "content": prompt}]
    for user, bot in history:
        hist += [{"role": "user", "content": user}]
        if bot:
            hist += [{"role": "assistant", "content": bot}]
    return openai.ChatCompletion.create(
        deployment_id="gpt-35-turbo-16k",
        messages=hist,
        temperature=temp,
        stream=stream,
    )


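# Append the user's message to the history and clear the textbox.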
def user(message, history):
    return "", history + [[message, None]]


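# Fetch supporting documents on the first turn and reuse them afterwards:
# rewrite the chat into a standalone query, then run a filtered hybrid search.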
def search(history, results, year, make, model):
    if results:
        return history, results

    query = gpt(history, translate_prompt, stream=False)["choices"][0]["message"][
        "content"
    ]
    print(query)

    filters = f"year eq {year} and make eq '{make}' and model eq '{model}'"
    res = vector_store.similarity_search(
        query, 5, search_type="hybrid", filters=filters
    )
    results = []
    for r in res:
        results.append(
            {
                "title": r.metadata["title"],
                "content": r.page_content,
            }
        )
    return history, results


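# Stream the assistant's reply, grounded in the retrieved documents.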
def bot(history, results):
    res = gpt(history, sys_prompt + str(results))
    history[-1][1] = ""
    for chunk in res:
        if "content" in chunk["choices"][0]["delta"]:
            history[-1][1] = history[-1][1] + chunk["choices"][0]["delta"]["content"]
            yield history


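# Gradio UI: vehicle selectors, retrieved-document panel, and streaming chat.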
with gr.Blocks(
    css="footer {visibility: hidden} #docs {height: 600px; overflow: auto !important}"
) as app:
    with gr.Row():
        year = gr.Dropdown(years, label="Year")
        make = gr.Dropdown([], label="Make", interactive=False)
        model = gr.Dropdown([], label="Model", interactive=False)
        year.change(filter_makes, year, make)
        make.change(filter_models, [year, make], model)
    with gr.Row():
        with gr.Column(scale=0.3333):
            results = []
            text = gr.JSON(None, elem_id="docs")
        with gr.Column(scale=0.6667):
            chatbot = gr.Chatbot(height=462)
            with gr.Row():
                msg = gr.Textbox(show_label=False, scale=7)
                # Each message: add it to the history, fetch documents if needed,
                # then stream the answer.
                msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                    search,
                    [chatbot, text, year, make, model],
                    [chatbot, text],
                    queue=False,
                ).then(bot, [chatbot, text], chatbot)
                btn = gr.Button("Send", variant="primary")
                btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                    search,
                    [chatbot, text, year, make, model],
                    [chatbot, text],
                    queue=False,
                ).then(bot, [chatbot, text], chatbot)
            with gr.Row():
                gr.Button("Clear").click(
                    lambda x, y: ([], None), [chatbot, text], [chatbot, text]
                )
                gr.Button("Undo").click(lambda x: (x[:-1]), [chatbot], [chatbot])

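# Require HTTP basic auth; the username and password come from the environment.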
app.queue().launch(auth=(os.getenv("USER"), os.getenv("PASSWORD")))