import streamlit as st
from datetime import time as t
import time

import openai

from langchain.embeddings import OpenAIEmbeddings
import pinecone


from results import results_agent
from filter import filter_agent
from reranker import reranker
from utils import build_filter
from router import routing_agent
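# Local helper agents (behavior inferred from how they are used below):
#   filter_agent  -- strips filter-style constraints out of the raw prompt
#   routing_agent -- decides whether a prompt needs the vector DB ("1") or plain chat
#   results_agent -- turns retrieval results into a conversational answer
#   reranker      -- re-scores retrieved matches with a BERT cross-encoder
#   build_filter  -- imported but not yet used; presumably builds a Pinecone metadata filter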

# API keys are read from Streamlit's secrets manager (.streamlit/secrets.toml).
OPENAI_API = st.secrets["OPENAI_API"]
PINECONE_API = st.secrets["PINECONE_API"]
openai.api_key = OPENAI_API


# Initialize the Pinecone client (v2-style API); "gcp-starter" is the free-tier environment.
pinecone.init(
    api_key=PINECONE_API,
    environment="gcp-starter"
)
index_name = "use-class-db"

embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API)

index = pinecone.Index(index_name)

st.title("USC GPT - Find the perfect class")

class_time = st.slider(
    "Filter Class Times:",
    value=(t(11, 30), t(12, 45)))

# st.write("You're scheduled for:", class_time)

units = st.slider(
    "Number of units",
    1, 4,
    value=(1, 4)
)
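# NOTE: class_time and units are collected above but never applied to the Pinecone
# query below. A minimal sketch of wiring them in, assuming the index stores numeric
# "units" and "start_minute" (minutes since midnight) metadata fields -- both field
# names are assumptions, not confirmed by this repo:
#
#   metadata_filter = {
#       "units": {"$gte": units[0], "$lte": units[1]},
#       # Pinecone range operators ($gte/$lte) only apply to numeric metadata,
#       # hence minutes since midnight rather than "HH:MM" strings.
#       "start_minute": {
#           "$gte": class_time[0].hour * 60 + class_time[0].minute,
#           "$lte": class_time[1].hour * 60 + class_time[1].minute,
#       },
#   }
#
# The dict could then be passed as filter=metadata_filter in index.query(...) inside
# get_rag_results (or built with the imported build_filter helper).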


initial_message = "How can I help you today?"

def get_rag_results(prompt):
    '''
    1. Strip filter-style constraints from the prompt so the semantic query is cleaner.
    2. Query the Pinecone DB and return the top 25 matches by cosine similarity.
    3. Rerank those matches with a BERT-based cross-encoder.
    '''
    # Step 1: filter_agent returns the prompt with filter constraints removed;
    # use that cleaned text as the semantic query.
    query = filter_agent(prompt, OPENAI_API)
    # Step 2: embed the cleaned query and fetch the 25 nearest neighbors.
    response = index.query(
        vector=embeddings.embed_query(query),
        top_k=25,
        include_metadata=True
    )
    # Step 3: rerank with the BERT cross-encoder.
    response = reranker(query, response)

    return response
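# Illustrative call (not part of the app flow; the query text is made up):
#   matches = get_rag_results("4-unit machine learning classes that meet after noon")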

if "messages" not in st.session_state:
    st.session_state.messages = []
    with st.chat_message("assistant"):
        st.markdown(initial_message)
    st.session_state.messages.append({"role": "assistant", "content": initial_message})
    

if prompt := st.chat_input("What kind of class are you looking for?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
            st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        # Flatten the chat history into one string for the routing agent.
        message_history = " ".join(m["content"] for m in st.session_state.messages)

        route = routing_agent(prompt, OPENAI_API, message_history)

        if route == "1":
        ## Option for accessing Vector DB
            rag_response = get_rag_results(prompt)
            result_query = 'Original Query:' + prompt + 'Query Results:' + str(rag_response)
            assistant_response = results_agent(result_query, OPENAI_API)
        else:    
        ## Option if not accessing Database
            assistant_response = openai.chatCompletion.create(
                model = "gpt-4",
                messages = [
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ]
            )["choices"][0]["message"]["content"]

        # Stream the response word-by-word for a typing effect, regardless of route.
        for chunk in assistant_response.split():
            full_response += chunk + " "
            time.sleep(0.05)
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})