# USC GPT — Streamlit chat app for finding USC classes (OpenAI + Pinecone RAG).
import streamlit as st | |
from datetime import time as t | |
import time | |
from operator import itemgetter | |
import os | |
import json | |
import getpass | |
import openai | |
from langchain.vectorstores import Pinecone | |
from langchain.embeddings import OpenAIEmbeddings | |
import pinecone | |
from results import results_agent | |
from filter import filter_agent | |
from reranker import reranker | |
from utils import build_filter | |
from router import routing_agent | |
# Credentials come from Streamlit secrets; never hard-code API keys.
OPENAI_API = st.secrets["OPENAI_API"]
PINECONE_API = st.secrets["PINECONE_API"]

openai.api_key = OPENAI_API

# Connect to the Pinecone vector DB holding the USC class embeddings.
pinecone.init(
    api_key=PINECONE_API,
    environment="gcp-starter",
)
index_name = "use-class-db"
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API)
index = pinecone.Index(index_name)

# NOTE(review): `k` is never read below (the query uses top_k=25) — confirm
# whether it was meant to control the number of surfaced results.
k = 5
st.title("USC GPT - Find the perfect class")

# Top-of-page filter widgets. NOTE(review): neither slider value is read
# anywhere visible in this file — presumably intended to feed a metadata
# filter on the Pinecone query; confirm before wiring them in.
class_time = st.slider(
    "Filter Class Times:",
    value=(t(11, 30), t(12, 45)),
)
units = st.slider(
    "Number of units",
    1, 4,
    value=(1, 4),
)

# Fix: removed `assistant = st.chat_message("assistant")` — it rendered an
# empty assistant bubble on every rerun and the variable was never used.

# Greeting shown at the start of a new conversation.
initial_message = "How can I help you today?"
def get_rag_results(prompt, top_k=25):
    """Retrieve and rerank candidate classes for a user query.

    Pipeline:
      1. Strip filter clauses from the prompt (via ``filter_agent``) so the
         semantic-search step isn't confused by them.
      2. Query the Pinecone index for the ``top_k`` nearest neighbours by
         cosine similarity over OpenAI embeddings.
      3. Rerank the hits with a BERT-based cross encoder (``reranker``).

    Args:
        prompt: Raw user query text.
        top_k: Candidates to pull from the vector DB before reranking
            (default 25, matching the original hard-coded value).

    Returns:
        The reranked results as produced by ``reranker``.
    """
    # Fix: the original assigned filter_agent's output to a variable and
    # immediately overwrote it with the Pinecone response, so the
    # filter-stripped prompt was silently discarded and the raw prompt was
    # embedded instead. Assumes filter_agent returns the cleaned prompt
    # text — confirm against filter.py.
    query = filter_agent(prompt, OPENAI_API)
    matches = index.query(
        vector=embeddings.embed_query(query),
        top_k=top_k,
        include_metadata=True,
    )
    return reranker(query, matches)  # BERT cross encoder for ranking
# ---- Conversation state ----------------------------------------------------
# Streamlit re-executes this script on every interaction, so the history
# lives in session_state. Fix: the original rendered AND appended the
# greeting on every rerun (it sat outside the init guard), duplicating it in
# the stored history, and it never replayed earlier turns. Seed the history
# once, then replay it each run — the canonical Streamlit chat pattern.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": initial_message}
    ]

# Replay the stored conversation on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("What kind of class are you looking for?"):
    # Record and echo the user's turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        # Route the query: "1" -> vector-DB RAG path; anything else ->
        # plain GPT-4 chat over the conversation so far.
        message_history = " ".join(m["content"] for m in st.session_state.messages)
        route = routing_agent(prompt, OPENAI_API, message_history)

        if route == "1":
            # RAG path: retrieve + rerank classes, then summarize.
            rag_response = get_rag_results(prompt)
            result_query = 'Original Query:' + prompt + 'Query Results:' + str(rag_response)
            assistant_response = results_agent(result_query, OPENAI_API)
        else:
            # Direct chat path, no database access.
            assistant_response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
            )["choices"][0]["message"]["content"]

        # Typewriter effect: reveal the response word by word with a
        # blinking-cursor glyph, then settle on the final text.
        for chunk in assistant_response.split():
            full_response += chunk + " "
            time.sleep(0.05)
            message_placeholder.markdown(full_response + "β")
        message_placeholder.markdown(full_response)

    st.session_state.messages.append({"role": "assistant", "content": full_response})