Spaces:
Running
Running
File size: 4,720 Bytes
a94de35 0aad6fd e115381 790fcbd 0aad6fd e115381 0aad6fd 85c189a 0aad6fd e115381 0aad6fd e115381 0aad6fd a94de35 58a8659 e115381 bd5c379 58a8659 bd5c379 58a8659 a94de35 e115381 a94de35 0aad6fd e115381 8d6c903 4dcc069 e48e44f 4dcc069 950465b e115381 bd5c379 d2bb19e 577cbf8 a94de35 e115381 a94de35 e115381 950465b e115381 a94de35 e115381 a94de35 e115381 a94de35 e115381 a94de35 e115381 a94de35 e115381 a94de35 e115381 341437d e115381 a94de35 341437d a94de35 341437d a94de35 341437d a94de35 e115381 a94de35 e115381 a94de35 e115381 a94de35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import streamlit as st
from teapotai import TeapotAI, TeapotAISettings
import hashlib
import os
import requests
import time
from langsmith import traceable
def log_time(func):
    """Decorate ``func`` so that every call prints how long it took.

    The wrapper forwards all positional and keyword arguments, returns
    ``func``'s result unchanged, and after the call completes prints
    ``"<func name> executed in <seconds> seconds"``. If ``func`` raises, the
    exception propagates and nothing is printed.

    ``functools.wraps`` copies ``func``'s ``__name__``, ``__doc__`` and related
    metadata onto the wrapper, so anything that inspects the decorated function
    (for example another decorator stacked on top) sees the original function
    rather than a generic ``wrapper``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the system clock and can jump, producing bogus (even negative) timings.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed:.4f} seconds")
        return result

    return wrapper
# Documents pre-filled into the sidebar RAG text area; empty by default.
default_documents = list()
# Brave Search subscription token from the environment (None when the
# "brave_api_key" variable is not set).
API_KEY = os.getenv("brave_api_key")
@log_time
def brave_search(query, count=3, timeout=10):
    """Query the Brave Search web API and return the top web results.

    Args:
        query: Search string, sent as the ``q`` parameter.
        count: Number of results to request, sent as the ``count`` parameter.
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` can block indefinitely and freeze the Streamlit run.

    Returns:
        A list of ``(title, description, url)`` tuples. It is empty when the
        request fails, the status code is not 200, the body is not valid
        JSON, or there are no web results. Errors are printed, not raised.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}

    try:
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as exc:
        # Network failures degrade to "no search context", the same way an
        # HTTP error status does below, instead of crashing the app.
        print(f"Error: {exc}")
        return []

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return []

    try:
        payload = response.json()
    except ValueError as exc:
        print(f"Error: invalid JSON response ({exc})")
        return []

    results = payload.get("web", {}).get("results", [])
    # Defensive: default missing fields to "" so callers can always unpack
    # three strings per result.
    return [
        (res.get("title", ""), res.get("description", ""), res.get("url", ""))
        for res in results
    ]
@traceable
@log_time
def query_teapot(prompt, context, user_input, teapot_ai):
    """Ask ``teapot_ai`` the user's question, with ``prompt`` prepended to the context.

    The instruction ``prompt`` and the retrieved ``context`` are joined by a
    single newline and passed as the model's ``context``. ``user_input`` is
    passed as ``query``. Returns whatever ``teapot_ai.query`` returns.
    """
    combined_context = prompt + "\n" + context
    return teapot_ai.query(query=user_input, context=combined_context)
def _strip_strong_tags(text):
    """Remove the ``<strong>``/``</strong>`` highlight tags from a search snippet."""
    return text.replace("<strong>", "").replace("</strong>", "")


@log_time
def handle_chat(user_input, teapot_ai):
    """Answer ``user_input`` using Brave web results as retrieval context.

    Searches Brave for the question and lists every hit in the sidebar
    (title, tag-stripped description, source link). The tag-stripped
    descriptions, one per line, are passed as context to ``query_teapot``
    together with a fixed system prompt.

    Args:
        user_input: The user's question; also used as the search query.
        teapot_ai: The model instance forwarded to ``query_teapot``.

    Returns:
        Whatever ``query_teapot`` returns for this question.
    """
    results = brave_search(user_input)
    # Strip the highlight tags once and reuse the clean text for both the
    # sidebar display and the model context.
    cleaned = [
        (title, _strip_strong_tags(description), url)
        for title, description, url in results
    ]

    st.sidebar.write("---")
    st.sidebar.write("## RAG Documents")
    for title, description, url in cleaned:
        st.sidebar.write(f"## {title}")
        st.sidebar.write(description)
        st.sidebar.write(f"[Source]({url})")
        st.sidebar.write("---")

    context = "\n".join(description for _, description, _ in cleaned)
    prompt = "You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."
    return query_teapot(prompt, context, user_input, teapot_ai)
def suggestion_button(suggestion_text, teapot_ai):
    """Render a button that asks ``suggestion_text`` when clicked.

    On the run in which the button was clicked:
    - the question and the model's answer are shown as chat messages, inside
      whatever container the button is rendered in;
    - both are appended to ``st.session_state.messages``, so they also appear
      in the replayed chat history on later runs.

    Without this, the answer computed by ``handle_chat`` would be discarded.
    Expects the caller to have initialised ``st.session_state.messages``.
    """
    if not st.button(suggestion_text):
        return

    with st.chat_message("user"):
        st.markdown(suggestion_text)
    st.session_state.messages.append({"role": "user", "content": suggestion_text})

    with st.spinner("Generating Response..."):
        response = handle_chat(suggestion_text, teapot_ai)

    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
@log_time
def hash_documents(documents):
    """Return a SHA-256 hex digest identifying the given list of documents.

    The documents are joined with newlines and UTF-8 encoded before hashing,
    so the same list in the same order always yields the same digest.
    """
    digest = hashlib.sha256()
    digest.update("\n".join(documents).encode("utf-8"))
    return digest.hexdigest()
def main():
    """Render the TeapotAI chat page for one Streamlit script run.

    Steps, in order:
      1. Configure the page and read newline-separated documents from the
         sidebar text area (entries stripped, blank lines dropped).
      2. Build a ``TeapotAI`` over those documents only when their hash
         differs from the one stored in ``st.session_state``; otherwise
         reuse the instance stored there.
      3. Seed (first run only) and replay the chat history kept in
         ``st.session_state.messages``.
      4. Show the "Suggested Questions" heading and the suggestion buttons
         beneath it, then answer any message typed into the chat input.
    """
    st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide")

    st.sidebar.header("Retrieval Augmented Generation")
    user_documents = st.sidebar.text_area(
        "Enter documents, each on a new line", value="\n".join(default_documents)
    )
    documents = [doc.strip() for doc in user_documents.split("\n") if doc.strip()]

    # Loading the model and embeddings is the slow step (see the spinner and
    # timing below), so the instance is kept in session state and rebuilt
    # only when the document set changes.
    new_documents_hash = hash_documents(documents)
    if st.session_state.get("documents_hash") != new_documents_hash:
        with st.spinner("Loading Model and Embeddings..."):
            start_time = time.perf_counter()
            teapot_ai = TeapotAI(
                documents=documents or default_documents,
                settings=TeapotAISettings(rag_num_results=3),
            )
            print(f"Model loaded in {time.perf_counter() - start_time:.4f} seconds")
        st.session_state.documents_hash = new_documents_hash
        st.session_state.teapot_ai = teapot_ai
    else:
        teapot_ai = st.session_state.teapot_ai

    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}
        ]
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_input = st.chat_input("Ask me anything")

    # Emit the heading before the buttons it labels so it renders above them.
    st.markdown("### Suggested Questions")
    suggestions = [
        "Tell me about the varieties of tea",
        "Who was born first, Alan Turing or John von Neumann?",
        "Extract Google's stock price",
    ]
    for column, suggestion in zip(st.columns([1, 2, 3]), suggestions):
        with column:
            suggestion_button(suggestion, teapot_ai)

    if user_input:
        with st.chat_message("user"):
            st.markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.spinner("Generating Response..."):
            response = handle_chat(user_input, teapot_ai)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
# Script entry point: build the page when this file is run as the main module.
# NOTE(review): presumably launched via `streamlit run` (the app is built from
# `st.*` calls) — confirm against the deployment configuration.
if __name__ == "__main__":
    main()
|