# USC-GPT / app.py
# NOTE: Hugging Face Spaces page chrome removed from the top of this file
# ("bhulston's picture / Update app.py / c37fc7c / raw / history / blame /
#  3.55 kB") — it was copied web-page text, not part of the program, and
# made the file unparseable.
import streamlit as st
from datetime import time as t
import time
from operator import itemgetter
import os
import json
import getpass
import openai
# Bug fix: removed `from openai import OpenAi` — that name does not exist in
# the openai package (the class is `OpenAI`, and only in openai>=1.0). This
# file uses the legacy module-level API (`openai.api_key`,
# `openai.ChatCompletion`), so the broken import raised ImportError before
# the app could even start, and the imported name was never used.
from langchain.vectorstores import Pinecone
from langchain.embeddings import OpenAIEmbeddings
import pinecone
from results import results_agent
from filter import filter_agent
from reranker import reranker
from utils import build_filter
from router import routing_agent

# Credentials are injected via Streamlit secrets (.streamlit/secrets.toml).
OPENAI_API = st.secrets["OPENAI_API"]
PINECONE_API = st.secrets["PINECONE_API"]
openai.api_key = OPENAI_API

# Connect to the Pinecone vector DB holding the USC class embeddings.
pinecone.init(
    api_key=PINECONE_API,
    environment="gcp-starter"
)
# NOTE(review): "use-class-db" looks like a typo for "usc-class-db" — confirm
# against the Pinecone console before changing; kept as-is to match the
# deployed index.
index_name = "use-class-db"
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API)
index = pinecone.Index(index_name)

k = 5  # number of results ultimately surfaced (top-k after reranking)
st.title("USC GPT - Find the perfect class")

# Filter widgets: meeting-time window and unit range.
class_time = st.slider(
    "Filter Class Times:",
    value=(t(11, 30), t(12, 45)))
# st.write("You're scheduled for:", class_time)
units = st.slider(
    "Number of units",
    1, 4,
    value=(1, 4)
)

# Bug fix: removed the unused `assistant = st.chat_message("assistant")` —
# it rendered a stray empty assistant bubble on every rerun and its return
# value was never used (the actual greeting is drawn further down inside a
# `with st.chat_message(...)` block).
initial_message = "How can I help you today?"
def get_rag_results(prompt):
    '''
    Retrieve and rank candidate classes for a user query.

    1. Remove filters from the prompt to optimize success of the RAG-based step.
    2. Query the Pinecone DB for the top 25 results by cosine similarity.
    3. Rerank those results with a BERT-based cross encoder.

    Parameters:
        prompt (str): the raw user query from the chat input.

    Returns:
        The reranked match list (whatever `reranker` returns — presumably the
        Pinecone matches reordered by cross-encoder score; verify in reranker.py).
    '''
    # Bug fix: the original assigned filter_agent()'s output to `response` and
    # immediately overwrote it with the Pinecone query result, discarding the
    # filter-stripped query that step 1 of the docstring describes. Use the
    # cleaned query for the embedding lookup instead.
    # NOTE(review): assumes filter_agent returns the cleaned query string —
    # confirm against filter.py.
    cleaned_query = filter_agent(prompt, OPENAI_API)
    response = index.query(
        vector=embeddings.embed_query(cleaned_query),
        top_k=25,
        include_metadata=True
    )
    # Rerank against the original prompt, as before.
    response = reranker(prompt, response)
    return response
# One-time conversation bootstrap: seed the history with the greeting.
# Bug fix: the greeting render/append now lives INSIDE the guard. Originally
# lines 65-67 ran on EVERY Streamlit rerun, appending a duplicate assistant
# greeting to st.session_state.messages after each user interaction.
if "messages" not in st.session_state:
    st.session_state.messages = []
    with st.chat_message("assistant"):
        st.markdown(initial_message)
    st.session_state.messages.append({"role": "assistant", "content": initial_message})
if prompt := st.chat_input("What kind of class are you looking for?"):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        # Flatten the conversation so the router can see prior context.
        messages = [{"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages]
        message_history = " ".join([message["content"] for message in messages])
        # Route "1" means the query needs the class vector DB; anything else
        # is answered directly by the chat model.
        route = routing_agent(prompt, OPENAI_API, message_history)
        if route == "1":
            ## Option for accessing Vector DB
            rag_response = get_rag_results(prompt)
            result_query = 'Original Query:' + prompt + 'Query Results:' + str(rag_response)
            assistant_response = results_agent(result_query, OPENAI_API)
        else:
            ## Option if not accessing Database
            # Bug fix: `openai.chatCompletion` is not an attribute of the
            # legacy SDK — the correct, case-sensitive name is ChatCompletion.
            assistant_response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ]
            )["choices"][0]["message"]["content"]
        ## Display response regardless of route, simulating token streaming.
        # Bug fix: the streaming cursor literal was the mojibake "β–Œ"
        # (UTF-8 "▌" decoded as Latin-1); restored to the intended block cursor.
        for chunk in assistant_response.split():
            full_response += chunk + " "
            time.sleep(0.05)
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
    st.session_state.messages.append({"role": "assistant", "content": full_response})