File size: 4,720 Bytes
a94de35
 
 
0aad6fd
 
e115381
790fcbd
0aad6fd
e115381
 
 
 
 
 
 
 
 
0aad6fd
 
85c189a
0aad6fd
e115381
0aad6fd
 
 
e115381
0aad6fd
 
 
 
 
 
 
 
 
 
a94de35
58a8659
e115381
bd5c379
58a8659
bd5c379
58a8659
 
 
a94de35
e115381
a94de35
0aad6fd
e115381
 
8d6c903
 
4dcc069
 
 
e48e44f
4dcc069
 
950465b
e115381
bd5c379
d2bb19e
577cbf8
a94de35
 
 
 
 
 
e115381
a94de35
 
 
 
 
e115381
950465b
e115381
a94de35
e115381
a94de35
e115381
a94de35
 
e115381
a94de35
e115381
 
a94de35
 
 
 
 
e115381
a94de35
 
 
 
 
 
e115381
341437d
e115381
a94de35
 
341437d
a94de35
341437d
a94de35
341437d
a94de35
 
 
 
e115381
a94de35
 
 
e115381
a94de35
 
e115381
a94de35
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import functools
import hashlib
import os
import time

import requests
import streamlit as st
from langsmith import traceable
from teapotai import TeapotAI, TeapotAISettings

def log_time(func):
    """Decorate *func* so every call prints how long it took.

    The wrapper forwards all positional and keyword arguments, returns
    whatever *func* returns, and prints
    ``"<name> executed in N.NNNN seconds"`` to stdout once the call returns.

    ``functools.wraps`` copies ``__name__``, ``__doc__``, ``__wrapped__`` etc.
    onto the wrapper. Without it, any decorator stacked on top (such as
    ``@traceable`` on ``query_teapot``) and any debugging tool would see a
    function called ``wrapper`` instead of the real one.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump if the system clock is adjusted.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed:.4f} seconds")
        return result
    return wrapper

# Documents pre-filled into the sidebar text area (one per line); none by default.
default_documents = list()

# Brave Search API subscription token, read from the ``brave_api_key``
# environment variable; None when that variable is not set.
API_KEY = os.getenv("brave_api_key")

@log_time
def brave_search(query, count=3, timeout=10):
    """Run a Brave Web Search and return the top results.

    Args:
        query: Search string, sent as the ``q`` parameter.
        count: Number of results requested from the API.
        timeout: Seconds to wait for the HTTP request. Without a timeout,
            ``requests`` can block indefinitely and freeze the app.

    Returns:
        A list of ``(title, description, url)`` tuples. A field missing from
        a result is returned as ``""``. On a network failure, a non-200
        status, or an unparseable body, the problem is printed and ``[]`` is
        returned, so the caller can still answer without web context.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}

    try:
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors / timeouts: degrade to "no results" instead of
        # crashing the Streamlit script run.
        print(f"Error: request to Brave Search failed: {exc}")
        return []

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return []

    try:
        results = response.json().get("web", {}).get("results", [])
    except ValueError as exc:
        # A 200 response whose body is not valid JSON.
        print(f"Error: invalid JSON from Brave Search: {exc}")
        return []

    print(results)
    # .get() so a partially populated result does not raise KeyError.
    return [
        (res.get("title", ""), res.get("description", ""), res.get("url", ""))
        for res in results
    ]

@traceable
@log_time
def query_teapot(prompt, context, user_input, teapot_ai):
    """Ask the TeapotAI model *user_input*, grounded in prompt plus context.

    The system ``prompt`` and the retrieved ``context`` are joined with a
    single newline and passed as the model's context. The model's answer
    from ``teapot_ai.query`` is returned unchanged.
    """
    combined_context = "\n".join((prompt, context))
    return teapot_ai.query(context=combined_context, query=user_input)

def _strip_strong_tags(text):
    """Remove ``<strong>`` / ``</strong>`` tags from a search-result snippet."""
    return text.replace('<strong>', '').replace('</strong>', '')


@log_time
def handle_chat(user_input, teapot_ai):
    """Answer *user_input* using live web-search results as context.

    Searches Brave for the question and lists every result (title, cleaned
    snippet, source link) in the sidebar. The cleaned snippets, one per
    line, are then passed with a fixed system prompt to ``query_teapot``.

    Returns:
        The value returned by ``query_teapot`` (the model's answer).
    """
    results = brave_search(user_input)
    # Clean each snippet once; the same text feeds both the sidebar and the
    # model context.
    cleaned = [
        (title, _strip_strong_tags(description), url)
        for title, description, url in results
    ]

    st.sidebar.write("---")
    st.sidebar.write("## RAG Documents")
    for title, description, url in cleaned:
        st.sidebar.write(f"## {title}")
        st.sidebar.write(description)
        st.sidebar.write(f"[Source]({url})")
        st.sidebar.write("---")

    context = "\n".join(description for _, description, _ in cleaned)
    prompt = "You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."
    return query_teapot(prompt, context, user_input, teapot_ai)

def suggestion_button(suggestion_text, teapot_ai):
    """Render a button that, when clicked, asks *suggestion_text* as a chat turn.

    On click, the question and the answer from ``handle_chat`` are rendered
    as chat messages in the current container. They are also appended to
    ``st.session_state.messages`` when that history exists, so the exchange
    persists on later reruns. This follows the same pattern ``main`` uses for
    typed input. Without this, the answer would be computed (web search plus
    model call) and then silently discarded.

    Returns:
        The model's response if the button was clicked on this run, else None.
    """
    if not st.button(suggestion_text):
        return None

    with st.chat_message("user"):
        st.markdown(suggestion_text)
    with st.spinner('Generating Response...'):
        response = handle_chat(suggestion_text, teapot_ai)
    with st.chat_message("assistant"):
        st.markdown(response)

    # Guarded so this helper does not fail if called before the chat history
    # has been initialised.
    if "messages" in st.session_state:
        st.session_state.messages.append({"role": "user", "content": suggestion_text})
        st.session_state.messages.append({"role": "assistant", "content": response})
    return response

@log_time
def hash_documents(documents):
    """Fingerprint a list of document strings with SHA-256.

    The documents are joined with newlines, UTF-8 encoded, and hashed. The
    hex digest lets callers detect cheaply when the document set changed.
    """
    hasher = hashlib.sha256()
    hasher.update("\n".join(documents).encode("utf-8"))
    return hasher.hexdigest()

def _get_teapot_ai(documents):
    """Return a TeapotAI instance for *documents*, rebuilding only on change.

    The instance is stored in ``st.session_state`` together with a hash of
    the documents, so reruns reuse the loaded model until the document list
    changes.
    """
    new_documents_hash = hash_documents(documents)
    if "documents_hash" in st.session_state and st.session_state.documents_hash == new_documents_hash:
        return st.session_state.teapot_ai

    with st.spinner('Loading Model and Embeddings...'):
        start_time = time.perf_counter()
        teapot_ai = TeapotAI(documents=documents or default_documents, settings=TeapotAISettings(rag_num_results=3))
        print(f"Model loaded in {time.perf_counter() - start_time:.4f} seconds")

    st.session_state.documents_hash = new_documents_hash
    st.session_state.teapot_ai = teapot_ai
    return teapot_ai


def main():
    """Streamlit entry point for the TeapotAI chat page.

    Render order on each run:
    1. Page config and the sidebar document editor.
    2. The (cached) model.
    3. The chat history.
    4. The answer to any newly typed question.
    5. The suggested-question buttons under their heading.
    """
    st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide")

    st.sidebar.header("Retrieval Augmented Generation")
    user_documents = st.sidebar.text_area("Enter documents, each on a new line", value="\n".join(default_documents))
    documents = [doc.strip() for doc in user_documents.split("\n") if doc.strip()]
    teapot_ai = _get_teapot_ai(documents)

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}]

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_input = st.chat_input("Ask me anything")

    # Handle the typed question before drawing the suggestions so the new
    # exchange appears directly under the existing history.
    if user_input:
        with st.chat_message("user"):
            st.markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        with st.spinner('Generating Response...'):
            response = handle_chat(user_input, teapot_ai)

        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

    # Always render the heading, immediately above the buttons it labels.
    st.markdown("### Suggested Questions")
    s1, s2, s3 = st.columns([1, 2, 3])
    with s1:
        suggestion_button("Tell me about the varieties of tea", teapot_ai)
    with s2:
        suggestion_button("Who was born first, Alan Turing or John von Neumann?", teapot_ai)
    with s3:
        suggestion_button("Extract Google's stock price", teapot_ai)

# Script entry point: launch the app only when this module is executed
# directly (e.g. by `streamlit run`), not when it is imported.
if __name__ == "__main__":
    main()