File size: 5,093 Bytes
c8a0f34 df1d046 7b6cdd4 c755297 61dbd5e 2587c2e c755297 e378588 34ce748 c755297 e16b761 c37fc7c c755297 d4c77e7 0a9e139 c755297 4871886 c755297 783ad43 c755297 54c11e5 307f1d8 c9d3a09 3ed1495 307f1d8 e16b761 c020cdf 6e37d5d 5dc1bc3 9f17ce8 c020cdf f573c2a 2587c2e 009017d e16b761 4e0f9dd 5dc1bc3 783ad43 5dc1bc3 c9d3a09 655400f 3ed1495 655400f 5dc1bc3 4e0f9dd 783ad43 5dc1bc3 4e0f9dd e16b761 5490950 c1e65f1 4e0f9dd e16b761 4e0f9dd 5dc1bc3 1e5c398 6164e6b 571d9c3 6164e6b 51a3672 4c8d045 7026099 93a7ed5 6164e6b 3c15656 6164e6b 4e0f9dd 845068b 4e0f9dd c1e65f1 71adbd5 571d9c3 845068b 571d9c3 50b1dd2 845068b 88ba472 9f17ce8 845068b 9f17ce8 4e0f9dd c1226e0 6164e6b 2587c2e 6164e6b 009017d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import streamlit as st
from datetime import time as t
import time
from operator import itemgetter
import os
import json
import getpass
import openai
import re
from langchain.vectorstores import Pinecone
from langchain.embeddings import OpenAIEmbeddings
import pinecone
from results import results_agent
from filter import filter_agent
from reranker import reranker
from utils import build_filter, clean_pinecone
from router import routing_agent
# API keys are read from Streamlit's secrets store (.streamlit/secrets.toml).
OPENAI_API = st.secrets["OPENAI_API"]
PINECONE_API = st.secrets["PINECONE_API"]
openai.api_key = OPENAI_API
# Connect to the Pinecone vector DB on the free "gcp-starter" environment.
pinecone.init(
api_key= PINECONE_API,
environment="gcp-starter"
)
index_name = "use-class-db"
embeddings = OpenAIEmbeddings(openai_api_key = OPENAI_API)
index = pinecone.Index(index_name)
# Number of nearest-neighbour matches requested from Pinecone per query.
k = 35
st.title("USC GPT - Find the perfect class")
# Filter widgets: their current values are read by get_rag_results() when
# building the Pinecone metadata filter for each query.
class_time = st.slider(
"Filter Class Times:",
value=(t(11, 30), t(12, 45))
)
units = st.slider(
"Number of units",
1, 4, 4
)
days = st.multiselect("What days are you free?",
options = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat"],
default = None,
placeholder = "Any day"
)
# NOTE(review): `assistant` is assigned here but never referenced below —
# presumably a leftover from an earlier revision; confirm before removing.
assistant = st.chat_message("assistant")
# Greeting injected once per session when the chat history is initialized.
initial_message = "Hello, I am your GPT-powered USC Class Helper! \n How can I assist you today?"
def get_rag_results(prompt):
    '''
    Retrieve candidate classes for *prompt* from the Pinecone vector DB.

    Steps:
    1. Strip filter-like language from the prompt (filter_agent) so the
       semantic query is cleaner.
    2. Build metadata filters from the widget values (class_time, units,
       days) and fetch the top `k` matches by cosine similarity.
    3. Re-rank the matches with a BERT-based cross encoder (reranker).

    Returns a (response, additional_metadata) pair; both are fallback
    strings when nothing matched the filters.
    '''
    cleaned_query = filter_agent(prompt, OPENAI_API)
    print("Here is the response from the filter_agent", cleaned_query)
    ## Encode the slider times as hour.minute floats (e.g. 11:30 -> 11.30)
    ## to match the encoding stored in the index metadata.
    earliest = float(class_time[0].hour) + float(class_time[0].minute) / 100.0
    latest = float(class_time[1].hour) + float(class_time[1].minute) / 100.0
    metadata_filter = {
        "start": {"$gte": earliest},
        "end": {"$lte": latest}
    }
    if units != "any":
        metadata_filter["units"] = str(int(units)) + ".0 units"
    if days:
        ## Accept classes meeting on any selected day or any ordered pair of
        ## selected days (metadata stores e.g. "Mon" or "Mon, Wed").
        allowed_days = []
        for i, first_day in enumerate(days):
            allowed_days.append(first_day)
            allowed_days.extend(
                first_day + ", " + second_day for second_day in days[i + 1:]
            )
        metadata_filter["days"] = {"$in": allowed_days}
    ## Query the pinecone database
    matches = index.query(
        vector = embeddings.embed_query(cleaned_query),
        top_k = k,
        filter = metadata_filter,
        include_metadata = True
    )
    matches, additional_metadata = clean_pinecone(matches)
    if not matches:
        return "No classes were found that matched your criteria", "None"
    ## BERT cross encoder for ranking
    return reranker(cleaned_query, matches), additional_metadata
if "messages" not in st.session_state:
st.session_state.messages = []
st.session_state.messages.append({"role": "assistant", "content": initial_message})
st.session_state.rag_responses = []
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("What kind of class are you looking for?"):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        # Keep only the last 6 turns as conversational context.
        messages = [{"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages[-6:]]
        message_history = " ".join([message["content"] for message in messages])
        print("Prompt is", prompt)
        rag_response, additional_metadata = get_rag_results(prompt)
        # get_rag_results returns a list of ranked results, or a plain
        # fallback string when nothing matched; only join the list case
        # (joining a string would space-separate its characters).
        if not isinstance(rag_response, str):
            rag_response = " ".join([message for message in rag_response])
        # Accumulate RAG results across the session so follow-up questions
        # can reference classes surfaced by earlier queries.
        st.session_state.rag_responses.append(rag_response)
        print("Here is the session state responses", st.session_state.rag_responses)
        all_rag_responses = " ".join([response for response in st.session_state.rag_responses])
        result_query = 'Original Query:' + prompt
        assistant_response = results_agent(result_query, "Class Options from RAG:" + all_rag_responses + "\nMessage_history" + message_history)
        ## Stream the answer with a typing effect. re.split with a captured
        ## whitespace group keeps the separators as chunks, so concatenating
        ## chunks as-is reproduces the text exactly (appending an extra " "
        ## per chunk would double every space).
        for chunk in re.split(r'(\s+)', assistant_response):
            full_response += chunk
            time.sleep(0.02)
            message_placeholder.markdown(full_response + "▌")
        # Final render without the typing cursor.
        message_placeholder.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})