import streamlit as st from teapotai import TeapotAI, TeapotAISettings import hashlib import os import requests import time from langsmith import traceable ##### Begin Library Code from transformers import pipeline import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import numpy as np from sklearn.metrics.pairwise import cosine_similarity from pydantic import BaseModel from typing import List, Optional from tqdm import tqdm import re import os class TeapotAISettings(BaseModel): """ Pydantic settings model for TeapotAI configuration. Attributes: use_rag (bool): Whether to use RAG (Retrieve and Generate). rag_num_results (int): Number of top documents to retrieve based on similarity. rag_similarity_threshold (float): Similarity threshold for document relevance. verbose (bool): Whether to print verbose updates. log_level (str): The log level for the application (e.g., "info", "debug"). """ use_rag: bool = True # Whether to use RAG (Retrieve and Generate) rag_num_results: int = 3 # Number of top documents to retrieve based on similarity rag_similarity_threshold: float = 0.5 # Similarity threshold for document relevance verbose: bool = True # Whether to print verbose updates log_level: str = "info" # Log level setting (e.g., 'info', 'debug') class TeapotAI: """ TeapotAI class that interacts with a language model for text generation and retrieval tasks. Attributes: model (str): The model identifier. model_revision (Optional[str]): The revision/version of the model. api_key (Optional[str]): API key for accessing the model (if required). settings (TeapotAISettings): Configuration settings for the AI instance. generator (callable): The pipeline for text generation. embedding_model (callable): The pipeline for feature extraction (document embeddings). documents (List[str]): List of documents for retrieval. document_embeddings (np.ndarray): Embeddings for the provided documents. """ def __init__(self, model_revision: Optional[str] = None, api_key: Optional[str] = None, documents: List[str] = [], settings: TeapotAISettings = TeapotAISettings()): """ Initializes the TeapotAI class with optional model_revision and api_key. Parameters: model_revision (Optional[str]): The revision/version of the model to use. api_key (Optional[str]): The API key for accessing the model if needed. documents (List[str]): A list of documents for retrieval. Defaults to an empty list. settings (TeapotAISettings): The settings configuration (defaults to TeapotAISettings()). """ self.model = "teapotai/teapotllm" self.model_revision = model_revision self.api_key = api_key self.settings = settings if self.settings.verbose: print(""" _____ _ _ ___ __o__ _;; |_ _|__ __ _ _ __ ___ | |_ / \ |_ _| __ /-___-\__/ / | |/ _ \/ _` | '_ \ / _ \| __| / _ \ | | ( | |__/ | | __/ (_| | |_) | (_) | |_ / ___ \ | | \_|~~~~~~~| |_|\___|\__,_| .__/ \___/ \__/ /_/ \_\___| \_____/ |_| """) if self.settings.verbose: print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}") # self.generator = pipeline("text2text-generation", model=self.model, revision=self.model_revision) if model_revision else pipeline("text2text-generation", model=self.model) self.tokenizer = AutoTokenizer.from_pretrained(self.model) model = AutoModelForSeq2SeqLM.from_pretrained(self.model) model.eval() # Quantization settings quantization_dtype = torch.qint8 # or torch.float16 quantization_config = torch.quantization.get_default_qconfig('fbgemm') # or 'onednn' self.quantized_model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=quantization_dtype ) self.documents = documents if self.settings.use_rag and self.documents: self.embedding_model = pipeline("feature-extraction", model="teapotai/teapotembedding") self.document_embeddings = self._generate_document_embeddings(self.documents) def _generate_document_embeddings(self, documents: List[str]) -> np.ndarray: """ Generate embeddings for the provided documents using the embedding model. Parameters: documents (List[str]): A list of document strings to generate embeddings for. Returns: np.ndarray: A NumPy array of document embeddings. """ embeddings = [] if self.settings.verbose: print("Generating embeddings for documents...") for doc in tqdm(documents, desc="Document Embedding", unit="doc"): embeddings.append(self.embedding_model(doc)[0][0]) else: for doc in documents: embeddings.append(self.embedding_model(doc)[0][0]) return np.array(embeddings) def rag(self, query: str) -> List[str]: """ Perform RAG (Retrieve and Generate) by finding the most relevant documents based on cosine similarity. Parameters: query (str): The query string to find relevant documents for. Returns: List[str]: A list of the top N most relevant documents. """ if not self.settings.use_rag or not self.documents: return [] query_embedding = self.embedding_model(query)[0][0] similarities = cosine_similarity([query_embedding], self.document_embeddings)[0] filtered_indices = [i for i, similarity in enumerate(similarities) if similarity >= self.settings.rag_similarity_threshold] top_n_indices = sorted(filtered_indices, key=lambda i: similarities[i], reverse=True)[:self.settings.rag_num_results] return [self.documents[i] for i in top_n_indices] def generate(self, input_text: str) -> str: """ Generate text based on the input string using the teapotllm model. Parameters: input_text (str): The text prompt to generate a response for. Returns: str: The generated output from the model. """ inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True) with torch.no_grad(): outputs = self.quantized_model.generate(inputs["input_ids"], max_length=512) result = self.tokenizer.decode(outputs[0], skip_special_tokens=True) if self.settings.log_level == "debug": print(input_text) print(result) return result def query(self, query: str, context: str = "") -> str: """ Handle a query and context, using RAG if no context is provided, and return a generated response. Parameters: query (str): The query string to be answered. context (str): The context to guide the response. Defaults to an empty string. Returns: str: The generated response based on the input query and context. """ if self.settings.use_rag and not context: context = "\n".join(self.rag(query)) # Perform RAG if no context is provided input_text = f"Context: {context}\nQuery: {query}" return self.generate(input_text) def chat(self, conversation_history: List[dict]) -> str: """ Engage in a chat by taking a list of previous messages and generating a response. Parameters: conversation_history (List[dict]): A list of previous messages, each containing 'content'. Returns: str: The generated response based on the conversation history. """ chat_history = "".join([message['content'] + "\n" for message in conversation_history]) if self.settings.use_rag: context_documents = self.rag(chat_history) # Perform RAG on the conversation history context = "\n".join(context_documents) chat_history = f"Context: {context}\n" + chat_history return self.generate(chat_history + "\n" + "agent:") def extract(self, class_annotation: BaseModel, query: str = "", context: str = "") -> BaseModel: """ Extract fields from a Pydantic class annotation by querying and processing each field. Parameters: class_annotation (BaseModel): The Pydantic class to extract fields from. query (str): The query string to guide the extraction. Defaults to an empty string. context (str): Optional context for the query. Returns: BaseModel: An instance of the provided Pydantic class with extracted field values. """ if self.settings.use_rag: context_documents = self.rag(query) context = "\n".join(context_documents) + context output = {} for field_name, field in class_annotation.__fields__.items(): type_annotation = field.annotation description = field.description description_annotation = f"({description})" if description else "" result = self.query(f"Extract the field {field_name} {description_annotation} to a {type_annotation}", context=context) # Process result based on field type if type_annotation == bool: parsed_result = ( True if re.search(r'\b(yes|true)\b', result, re.IGNORECASE) else (False if re.search(r'\b(no|false)\b', result, re.IGNORECASE) else None) ) elif type_annotation in [int, float]: parsed_result = re.sub(r'[^0-9.]', '', result) if parsed_result: try: parsed_result = type_annotation(parsed_result) except Exception: parsed_result = None else: parsed_result = None elif type_annotation == str: parsed_result = result.strip() else: raise ValueError(f"Unsupported type annotation: {type_annotation}") output[field_name] = parsed_result return class_annotation(**output) ##### End Library Code def log_time(func): def wrapper(*args, **kwargs): start_time = time.time() result = func(*args, **kwargs) end_time = time.time() print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds") return result return wrapper default_documents = [] API_KEY = os.environ.get("brave_api_key") @log_time def brave_search(query, count=3): url = "https://api.search.brave.com/res/v1/web/search" headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY} params = {"q": query, "count": count} response = requests.get(url, headers=headers, params=params) if response.status_code == 200: results = response.json().get("web", {}).get("results", []) print(results) return [(res["title"], res["description"], res["url"]) for res in results] else: print(f"Error: {response.status_code}, {response.text}") return [] @traceable @log_time def query_teapot(prompt, context, user_input, teapot_ai): response = teapot_ai.query( context=prompt+"\n"+context, query=user_input ) return response @log_time def handle_chat(user_input, teapot_ai): results = brave_search(user_input) documents = [desc.replace('','').replace('','') for _, desc, _ in results] st.sidebar.write("---") st.sidebar.write("## RAG Documents") for (title, description, url) in results: # Display Results st.sidebar.write(f"## {title}") st.sidebar.write(f"{description.replace('','').replace('','')}") st.sidebar.write(f"[Source]({url})") st.sidebar.write("---") context = "\n".join(documents) prompt = "You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization." response = query_teapot(prompt, context, user_input, teapot_ai) return response def suggestion_button(suggestion_text, teapot_ai): if st.button(suggestion_text): handle_chat(suggestion_text, teapot_ai) @log_time def hash_documents(documents): return hashlib.sha256("\n".join(documents).encode("utf-8")).hexdigest() def main(): st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide") st.sidebar.header("Retrieval Augmented Generation") user_documents = st.sidebar.text_area("Enter documents, each on a new line", value="\n".join(default_documents)) documents = [doc.strip() for doc in user_documents.split("\n") if doc.strip()] new_documents_hash = hash_documents(documents) if "documents_hash" not in st.session_state or st.session_state.documents_hash != new_documents_hash: with st.spinner('Loading Model and Embeddings...'): start_time = time.time() teapot_ai = TeapotAI(documents=documents or default_documents, settings=TeapotAISettings(rag_num_results=3)) end_time = time.time() print(f"Model loaded in {end_time - start_time:.4f} seconds") st.session_state.documents_hash = new_documents_hash st.session_state.teapot_ai = teapot_ai else: teapot_ai = st.session_state.teapot_ai if "messages" not in st.session_state: st.session_state.messages = [{"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}] for message in st.session_state.messages: with st.chat_message(message["role"]): st.markdown(message["content"]) user_input = st.chat_input("Ask me anything") s1, s2, s3 = st.columns([1, 2, 3]) with s1: suggestion_button("Tell me about the varieties of tea", teapot_ai) with s2: suggestion_button("Who was born first, Alan Turing or John von Neumann?", teapot_ai) with s3: suggestion_button("Extract Google's stock price", teapot_ai) if user_input: with st.chat_message("user"): st.markdown(user_input) st.session_state.messages.append({"role": "user", "content": user_input}) with st.spinner('Generating Response...'): response = handle_chat(user_input, teapot_ai) with st.chat_message("assistant"): st.markdown(response) st.session_state.messages.append({"role": "assistant", "content": response}) st.markdown("### Suggested Questions") if __name__ == "__main__": main()