# Spaces: Running on T4
import functools
import hashlib
import html
import os
import time

import requests
import streamlit as st
from langsmith import traceable
from teapotai import TeapotAI, TeapotAISettings
def log_time(func):
    """Decorate *func* so that each call prints how long it took.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        *func*, returns its result unchanged, and keeps *func*'s name and
        docstring (via ``functools.wraps``).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() is wall
        # clock time and can jump (NTP, DST), producing bogus durations.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed:.4f} seconds")
        return result

    return wrapper
# Initial contents of the sidebar document box, and the fallback document
# list handed to TeapotAI when the user enters none.
default_documents = list()

# Brave Search subscription token taken from the environment; None if unset.
API_KEY = os.getenv("brave_api_key")
def brave_search(query, count=3, timeout=10):
    """Search the web with the Brave Search API.

    Args:
        query: The search string.
        count: Maximum number of results to request.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A list of ``(title, description, url)`` tuples. Returns ``[]`` when
        the request fails, times out, returns a non-200 status, or the body
        is not valid JSON. Missing fields come back as empty strings.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}
    try:
        # Without a timeout a stalled connection would block the app forever.
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error: {err}")
        return []

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return []

    try:
        results = response.json().get("web", {}).get("results", [])
    except ValueError as err:  # body was not valid JSON
        print(f"Error: invalid JSON response: {err}")
        return []

    # Use .get so a result missing an optional field does not raise KeyError.
    return [
        (res.get("title", ""), res.get("description", ""), res.get("url", ""))
        for res in results
    ]
def query_teapot(prompt, context, user_input, teapot_ai):
    """Ask *teapot_ai* a question with a system prompt prepended to the context.

    Args:
        prompt: Instruction text placed before the retrieved context.
        context: Retrieved text the model should answer from.
        user_input: The question, passed to the model as ``query``.
        teapot_ai: Object exposing ``query(context=..., query=...)``.

    Returns:
        Whatever ``teapot_ai.query`` returns.
    """
    full_context = "\n".join([prompt, context])
    return teapot_ai.query(query=user_input, context=full_context)
_SYSTEM_PROMPT = "You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."


def _clean_snippet(text):
    """Remove <strong> highlight tags from a search snippet and decode HTML entities."""
    # Strip tags first, then unescape, so escaped text like "&lt;strong&gt;"
    # survives as literal characters instead of being removed.
    return html.unescape(text.replace("<strong>", "").replace("</strong>", ""))


def handle_chat(user_input, teapot_ai):
    """Answer *user_input* with TeapotAI, using web search results as context.

    The search results are also listed in the sidebar under "RAG Documents"
    so the user can see which sources were supplied to the model.

    Args:
        user_input: The user's question; it is also used as the search query.
        teapot_ai: The model object passed through to ``query_teapot``.

    Returns:
        The value returned by ``query_teapot`` for this question.
    """
    results = brave_search(user_input)
    # Clean once and reuse for both the sidebar display and the model context.
    cleaned = [
        (_clean_snippet(title), _clean_snippet(description), url)
        for title, description, url in results
    ]

    st.sidebar.write("---")
    st.sidebar.write("## RAG Documents")
    for title, description, url in cleaned:
        st.sidebar.write(f"## {title}")
        st.sidebar.write(description)
        st.sidebar.write(f"[Source]({url})")
        st.sidebar.write("---")

    context = "\n".join(description for _, description, _ in cleaned)
    return query_teapot(_SYSTEM_PROMPT, context, user_input, teapot_ai)
def suggestion_button(suggestion_text, teapot_ai):
    """Render a button that submits *suggestion_text* as if the user had typed it.

    When clicked, the question and the model's answer are shown as chat
    messages and appended to ``st.session_state.messages``, so they also
    appear in the chat history on later reruns.

    Args:
        suggestion_text: Button label, also used as the question.
        teapot_ai: The model object passed to ``handle_chat``.

    Returns:
        The model's response if the button was clicked on this run, else None.
    """
    if not st.button(suggestion_text):
        return None

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # These messages render in the container the button lives in (for example
    # a column). From the next rerun on, the history loop shows them in order.
    with st.chat_message("user"):
        st.markdown(suggestion_text)
    st.session_state.messages.append({"role": "user", "content": suggestion_text})

    with st.spinner('Generating Response...'):
        response = handle_chat(suggestion_text, teapot_ai)

    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
    return response
def hash_documents(documents):
    """Return a hex SHA-256 fingerprint of *documents*.

    The documents are joined with newlines and encoded as UTF-8 before
    hashing, so the same list of strings always yields the same digest.
    """
    joined = "\n".join(documents)
    digest = hashlib.sha256()
    digest.update(joined.encode("utf-8"))
    return digest.hexdigest()
def _get_teapot_ai(documents):
    """Return this session's TeapotAI instance, rebuilding it when *documents* change.

    The instance is stored in ``st.session_state`` together with a hash of the
    document list, so the model is only reloaded when the sidebar text changes.
    """
    new_documents_hash = hash_documents(documents)
    if "documents_hash" in st.session_state and st.session_state.documents_hash == new_documents_hash:
        return st.session_state.teapot_ai

    with st.spinner('Loading Model and Embeddings...'):
        start_time = time.perf_counter()
        teapot_ai = TeapotAI(documents=documents or default_documents, settings=TeapotAISettings(rag_num_results=3))
        print(f"Model loaded in {time.perf_counter() - start_time:.4f} seconds")
    # Store the hash only after a successful load, so a failed load is retried.
    st.session_state.documents_hash = new_documents_hash
    st.session_state.teapot_ai = teapot_ai
    return teapot_ai


def _answer_user_message(user_input, teapot_ai):
    """Show *user_input*, generate a reply, and append both turns to the chat history."""
    with st.chat_message("user"):
        st.markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.spinner('Generating Response...'):
        response = handle_chat(user_input, teapot_ai)
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})


def main():
    """Build the TeapotAI chat page.

    Render order, top to bottom: the sidebar document box, the chat history,
    the reply to any newly submitted message, then the suggested questions.
    """
    st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide")

    st.sidebar.header("Retrieval Augmented Generation")
    user_documents = st.sidebar.text_area("Enter documents, each on a new line", value="\n".join(default_documents))
    documents = [doc.strip() for doc in user_documents.split("\n") if doc.strip()]
    teapot_ai = _get_teapot_ai(documents)

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}]
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_input = st.chat_input("Ask me anything")
    # Answer before drawing the suggestions, so the new turn sits directly
    # under the history instead of below the suggestion buttons.
    if user_input:
        _answer_user_message(user_input, teapot_ai)

    # The heading labels the buttons, so it must be rendered before them.
    st.markdown("### Suggested Questions")
    s1, s2, s3 = st.columns([1, 2, 3])
    with s1:
        suggestion_button("Tell me about the varieties of tea", teapot_ai)
    with s2:
        suggestion_button("Who was born first, Alan Turing or John von Neumann?", teapot_ai)
    with s3:
        suggestion_button("Extract Google's stock price", teapot_ai)
if __name__ == "__main__":
    # Start the app only when this file is the entry script, not on import.
    main()