Spaces:
Running
on
T4
Running
on
T4
File size: 14,648 Bytes
a94de35 0aad6fd e115381 790fcbd fa73971 d6dc06a 968f358 492a988 e115381 0aad6fd 85c189a 0aad6fd e115381 0aad6fd e115381 0aad6fd a94de35 58a8659 e115381 bd5c379 58a8659 bd5c379 58a8659 a94de35 e115381 289d0f6 51db61d 0aad6fd e115381 8d6c903 4dcc069 e48e44f 4dcc069 950465b e115381 f1ebe90 289d0f6 62c86e1 577cbf8 a94de35 45f3194 a94de35 e115381 950465b 289d0f6 e115381 51db61d ce7bb20 e115381 f1ebe90 580cfe1 a2373c5 fa73971 08dec69 fa73971 f1ebe90 d366837 e04907a 4a56dea ce7bb20 a94de35 b962028 ce7bb20 4a56dea a94de35 4a56dea ce7bb20 4a56dea a94de35 b962028 ce7bb20 4a56dea ce7bb20 e04907a ce7bb20 3366ac8 ce7bb20 a94de35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 |
import streamlit as st
from teapotai import TeapotAI, TeapotAISettings
import hashlib
import os
import requests
import time
from langsmith import traceable
import random
### Begin Library Code
from transformers import pipeline
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from pydantic import BaseModel
from typing import List, Optional
from tqdm import tqdm
import re
import os
class TeapotAISettings(BaseModel):
    """
    Configuration values consumed by TeapotAI.

    Attributes:
        use_rag (bool): Enables document retrieval before generation.
        rag_num_results (int): Upper bound on how many retrieved documents are kept.
        rag_similarity_threshold (float): Minimum cosine similarity a document must
            reach to be kept by retrieval.
        verbose (bool): Enables the start-up banner and progress output.
        log_level (str): Logging verbosity; "debug" makes generation print the
            prompt and the model output.
    """

    # Retrieval is on by default.
    use_rag: bool = True
    # Keep at most this many of the best-matching documents.
    rag_num_results: int = 3
    # Documents scoring below this cosine similarity are discarded.
    rag_similarity_threshold: float = 0.5
    # Print the banner / progress messages.
    verbose: bool = True
    # Free-form level name, e.g. 'info' or 'debug'.
    log_level: str = "info"
class TeapotAI:
    """
    Wrapper around the teapotllm text2text model with optional document retrieval.

    Attributes:
        model (str): The model identifier.
        model_revision (Optional[str]): The revision/version of the model.
        api_key (Optional[str]): API key stored on the instance (not used by this class).
        settings (TeapotAISettings): Configuration settings for the AI instance.
        generator (callable): The pipeline for text generation.
        embedding_model (Optional[callable]): The feature-extraction pipeline, or None
            when retrieval is disabled or no documents were supplied.
        documents (List[str]): List of documents for retrieval.
        document_embeddings (Optional[np.ndarray]): Embeddings for the provided
            documents, or None when no embeddings were computed.
    """

    # First number in a piece of text: optional sign, digits with optional
    # thousands separators, optional fractional part.
    _NUMBER_PATTERN = re.compile(r'[-+]?(?:\d[\d,]*(?:\.\d+)?|\.\d+)')

    def __init__(self, model_revision: Optional[str] = None, api_key: Optional[str] = None,
                 documents: Optional[List[str]] = None,
                 settings: Optional[TeapotAISettings] = None):
        """
        Load the generation model and, when retrieval is enabled, embed the documents.

        Parameters:
            model_revision (Optional[str]): The revision/version of the model to use.
            api_key (Optional[str]): The API key for accessing the model if needed.
            documents (Optional[List[str]]): Documents for retrieval. Defaults to no documents.
            settings (Optional[TeapotAISettings]): The settings configuration
                (defaults to a fresh TeapotAISettings()).
        """
        self.model = "teapotai/teapotllm"
        self.model_revision = model_revision
        self.api_key = api_key
        # Build defaults per instance so instances never share a mutable default.
        self.settings = settings if settings is not None else TeapotAISettings()

        if self.settings.verbose:
            # Raw string: the banner's backslashes are literal, not escape sequences.
            print(r""" _____ _ _ ___ __o__ _;;
|_ _|__ __ _ _ __ ___ | |_ / \ |_ _| __ /-___-\__/ /
| |/ _ \/ _` | '_ \ / _ \| __| / _ \ | | ( | |__/
| | __/ (_| | |_) | (_) | |_ / ___ \ | | \_|~~~~~~~|
|_|\___|\__,_| .__/ \___/ \__/ /_/ \_\___| \_____/
|_| """)
            print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}")

        # Only pass a revision when one was actually requested.
        generator_kwargs = {"model": self.model}
        if model_revision:
            generator_kwargs["revision"] = self.model_revision
        self.generator = pipeline("text2text-generation", **generator_kwargs)

        # Copy so later mutation of the caller's list cannot desynchronise
        # the documents from their embeddings.
        self.documents = list(documents) if documents is not None else []

        # Always define the retrieval attributes so they can be inspected safely.
        self.embedding_model = None
        self.document_embeddings = None
        if self.settings.use_rag and self.documents:
            self.embedding_model = pipeline("feature-extraction", model="teapotai/teapotembedding")
            self.document_embeddings = self._generate_document_embeddings(self.documents)

    def _generate_document_embeddings(self, documents: List[str]) -> np.ndarray:
        """
        Generate embeddings for the provided documents using the embedding model.

        Parameters:
            documents (List[str]): A list of document strings to generate embeddings for.

        Returns:
            np.ndarray: A NumPy array with one embedding per document.
        """
        if self.settings.verbose:
            print("Generating embeddings for documents...")
            iterator = tqdm(documents, desc="Document Embedding", unit="doc")
        else:
            iterator = documents
        # [0][0] selects the first vector of the first sequence returned by the pipeline.
        return np.array([self.embedding_model(doc)[0][0] for doc in iterator])

    def rag(self, query: str) -> List[str]:
        """
        Retrieve the documents most similar to the query by cosine similarity.

        Parameters:
            query (str): The query string to find relevant documents for.

        Returns:
            List[str]: Up to ``rag_num_results`` documents whose similarity is at least
            ``rag_similarity_threshold``, best match first. Empty when retrieval is
            disabled or there are no documents.
        """
        if not self.settings.use_rag or not self.documents:
            return []
        query_embedding = self.embedding_model(query)[0][0]
        similarities = cosine_similarity([query_embedding], self.document_embeddings)[0]
        threshold = self.settings.rag_similarity_threshold
        candidates = [index for index, score in enumerate(similarities) if score >= threshold]
        candidates.sort(key=lambda index: similarities[index], reverse=True)
        return [self.documents[index] for index in candidates[:self.settings.rag_num_results]]

    def generate(self, input_text: str) -> str:
        """
        Generate text based on the input string using the teapotllm model.

        Parameters:
            input_text (str): The text prompt to generate a response for.

        Returns:
            str: The generated output from the model.
        """
        result = self.generator(input_text, max_length=512)[0].get("generated_text")
        if self.settings.log_level == "debug":
            print(input_text)
            print(result)
        return result

    def query(self, query: str, context: str = "") -> str:
        """
        Answer a query, retrieving context via RAG when none is provided.

        Parameters:
            query (str): The query string to be answered.
            context (str): The context to guide the response. Defaults to an empty string.

        Returns:
            str: The generated response based on the input query and context.
        """
        if self.settings.use_rag and not context:
            context = "\n".join(self.rag(query))  # Perform RAG if no context is provided
        input_text = f"Context: {context}\nQuery: {query}"
        return self.generate(input_text)

    def chat(self, conversation_history: List[dict]) -> str:
        """
        Generate the next reply for a list of previous messages.

        Parameters:
            conversation_history (List[dict]): Previous messages, each containing 'content'.

        Returns:
            str: The generated response based on the conversation history.
        """
        chat_history = "".join(f"{message['content']}\n" for message in conversation_history)
        if self.settings.use_rag:
            # Retrieve against the whole conversation so far.
            context = "\n".join(self.rag(chat_history))
            chat_history = f"Context: {context}\n" + chat_history
        return self.generate(chat_history + "\n" + "agent:")

    @staticmethod
    def _parse_bool(text: str) -> Optional[bool]:
        """Map a yes/true or no/false answer to a bool; None when neither word appears."""
        if re.search(r'\b(yes|true)\b', text, re.IGNORECASE):
            return True
        if re.search(r'\b(no|false)\b', text, re.IGNORECASE):
            return False
        return None

    @classmethod
    def _parse_number(cls, text: str, numeric_type):
        """Convert the first number found in ``text`` to ``numeric_type``; None on failure."""
        match = cls._NUMBER_PATTERN.search(text)
        if match is None:
            return None
        # Drop thousands separators before conversion.
        token = match.group().replace(',', '')
        try:
            return numeric_type(token)
        except ValueError:
            # e.g. int("3.5"): the answer does not fit the requested type.
            return None

    def extract(self, class_annotation: BaseModel, query: str = "", context: str = "") -> BaseModel:
        """
        Fill a Pydantic model by asking the language model for each field in turn.

        Parameters:
            class_annotation (BaseModel): The Pydantic model class to populate.
            query (str): The query string to guide retrieval. Defaults to an empty string.
            context (str): Optional context for the query.

        Returns:
            BaseModel: An instance of the provided Pydantic class with extracted field values.

        Raises:
            ValueError: If a field is annotated with a type other than bool, int, float or str.
        """
        if self.settings.use_rag:
            # Keep retrieved documents and caller-supplied context on separate lines.
            context_parts = list(self.rag(query))
            if context:
                context_parts.append(context)
            context = "\n".join(context_parts)

        # Prefer the pydantic v2 accessor; fall back to the legacy attribute.
        fields = getattr(class_annotation, "model_fields", None)
        if fields is None:
            fields = class_annotation.__fields__

        output = {}
        for field_name, field in fields.items():
            type_annotation = field.annotation
            description = field.description
            description_annotation = f"({description})" if description else ""
            result = self.query(f"Extract the field {field_name} {description_annotation} to a {type_annotation}", context=context)

            # Process result based on field type
            if type_annotation is bool:
                parsed_result = self._parse_bool(result)
            elif type_annotation in (int, float):
                parsed_result = self._parse_number(result, type_annotation)
            elif type_annotation is str:
                parsed_result = result.strip()
            else:
                raise ValueError(f"Unsupported type annotation: {type_annotation}")
            output[field_name] = parsed_result
        return class_annotation(**output)
### End Library Code
def log_time(func):
    """
    Decorator that prints how long each call to ``func`` takes.

    Parameters:
        func (callable): The function to time.

    Returns:
        callable: A wrapper that forwards all arguments, returns ``func``'s result
        and keeps ``func``'s name and docstring.
    """
    import functools

    # Preserve the wrapped function's metadata so outer decorators and logs
    # see the real name instead of "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the duration cannot be skewed by clock changes.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
        return result

    return wrapper
# Placeholder list of documents; starts empty.
default_documents = list()
# Brave Search subscription token read from the environment (None when unset).
API_KEY = os.getenv("brave_api_key")
@log_time
def brave_search(query, count=3, timeout=10):
    """
    Query the Brave web search API and return the top results.

    Parameters:
        query (str): The search query.
        count (int): Number of results to request. Defaults to 3.
        timeout (float): Seconds to wait for the HTTP request before giving up.

    Returns:
        list: ``(title, description, url)`` tuples, one per result. Empty when the
        request fails, the response is not HTTP 200, or the body is not valid JSON.
        Fields missing from a result are returned as empty strings.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}

    try:
        # A timeout keeps a stalled connection from hanging the app.
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: {exc}")
        return []

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return []

    try:
        payload = response.json()
    except ValueError as exc:
        print(f"Error: {exc}")
        return []

    # Tolerate a missing or null "web" / "results" entry.
    results = (payload.get("web") or {}).get("results") or []
    print(results)
    return [
        (res.get("title", ""), res.get("description", ""), res.get("url", ""))
        for res in results
    ]
@traceable
@log_time
def query_teapot(prompt, context, user_input, teapot_ai):
    """
    Ask the model a question, with the system prompt placed ahead of the context.

    Parameters:
        prompt (str): Instruction text placed on the line before the context.
        context (str): Supporting text for the answer.
        user_input (str): The question to answer.
        teapot_ai: Object exposing ``query(context=..., query=...)``.

    Returns:
        The value returned by ``teapot_ai.query``.
    """
    full_context = "\n".join((prompt, context))
    return teapot_ai.query(query=user_input, context=full_context)
@log_time
def handle_chat(user_prompt, user_input, teapot_ai):
    """
    Run one chat turn: show the user message, search the web, and answer.

    Parameters:
        user_prompt (str): Extra text from the sidebar, appended to the search
            snippets as additional context. May be empty.
        user_input (str): The user's message; also used as the search query.
        teapot_ai: Model wrapper passed through to ``query_teapot``.

    Returns:
        str: The assistant's response, which is also rendered and stored in
        ``st.session_state.messages``.
    """
    with st.chat_message("user"):
        st.markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    results = brave_search(user_input)
    # Strip the highlight tags once and reuse the cleaned text everywhere.
    cleaned_results = [
        (title, description.replace('<strong>', '').replace('</strong>', ''), url)
        for title, description, url in results
    ]

    st.sidebar.write("---")
    st.sidebar.write("## RAG Documents")
    for title, description, url in cleaned_results:
        # Display Results
        st.sidebar.write(f"## {title}")
        st.sidebar.write(f"{description}")
        st.sidebar.write(f"[Source]({url})")
        st.sidebar.write("---")

    # Keep the user's prompt on its own line rather than glued to the last snippet.
    context_parts = [description for _, description, _ in cleaned_results]
    if user_prompt:
        context_parts.append(user_prompt)
    context = "\n".join(context_parts)

    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. If a user asks who you are reply "I am Teapot"."""
    response = query_teapot(prompt, context, user_input, teapot_ai)

    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
    return response
@st.cache_resource
def _load_teapot_ai():
    """Build the TeapotAI instance once and reuse it across Streamlit reruns."""
    return TeapotAI(documents=[], settings=TeapotAISettings(rag_num_results=3, log_level="debug"))


def _pick_suggestions():
    """Return one randomly chosen example question from each category."""
    list1 = ["Tell me about teapotllm", "What is Teapot AI?","What devices can Teapot run on?","Who are you?"]
    list2 = ["Who invented quantum mechanics?", "Who are the authors of attention is all you need", "Tell me about popular places to travel in France","Summarize the book irobot", "Explain artificial intelligence","what are the key ingredients of bouillabaisse"]
    list3 = ["Extract the year Google was founded", "Extract the last name of the father of artificial intelligence", "Output the capital of New York","Extarct the city where the louvre is located","Find the chemical symbol for gold","Extract the name of the woman who was the first computer programmer"]
    return [random.choice(list1), random.choice(list2), random.choice(list3)]


def main():
    """
    Render the TeapotAI chat page and handle one user interaction per rerun.

    The message to answer comes either from a clicked suggestion button or
    from the chat input box.
    """
    st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide")
    st.sidebar.header("Retrieval Augmented Generation")
    user_prompt = st.sidebar.text_area("Enter prompt, leave empty for search")
    teapot_ai = _load_teapot_ai()

    # Keep the suggestions stable for the session: a button click triggers a
    # rerun, and the click only registers if the same button is rendered again.
    if "suggestions" not in st.session_state:
        st.session_state.suggestions = _pick_suggestions()

    user_suggested_input = None
    columns = st.columns([1, 1, 1])
    for column, suggestion in zip(columns, st.session_state.suggestions):
        with column:
            if st.button(suggestion, use_container_width=True):
                user_suggested_input = suggestion

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}]
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_input = st.chat_input("Ask me anything")
    # The chat box is empty on a button-click rerun, so either source must
    # be able to start a chat turn.
    chosen_input = user_suggested_input or user_input
    if chosen_input:
        with st.spinner('Generating Response...'):
            handle_chat(user_prompt, chosen_input, teapot_ai)


if __name__ == "__main__":
    main()
|