import hashlib
import os
import time
from functools import wraps

import requests
import streamlit as st
from langsmith import traceable

# TeapotAI and TeapotAISettings are defined inline below ("Begin Library Code"),
# so the teapotai package does not need to be imported here.


##### Begin Library Code
import re
from typing import List, Optional

import numpy as np
import torch
from pydantic import BaseModel
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline


class TeapotAISettings(BaseModel):
    """
    Pydantic settings model for TeapotAI configuration.
    
    Attributes:
        use_rag (bool): Whether to use RAG (Retrieval-Augmented Generation).
        rag_num_results (int): Number of top documents to retrieve based on similarity.
        rag_similarity_threshold (float): Similarity threshold for document relevance.
        verbose (bool): Whether to print verbose updates.
        log_level (str): The log level for the application (e.g., "info", "debug").
    """
    use_rag: bool = True  # Whether to use RAG (Retrieval-Augmented Generation)
    rag_num_results: int = 3  # Number of top documents to retrieve based on similarity
    rag_similarity_threshold: float = 0.5  # Similarity threshold for document relevance
    verbose: bool = True  # Whether to print verbose updates
    log_level: str = "info"  # Log level setting (e.g., 'info', 'debug')
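
# A minimal usage sketch (hypothetical values): tighten retrieval and silence logs.
#   settings = TeapotAISettings(rag_num_results=5, rag_similarity_threshold=0.7, verbose=False)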


class TeapotAI:
    """
    TeapotAI class that interacts with a language model for text generation and retrieval tasks.
    
    Attributes:
        model (str): The model identifier.
        model_revision (Optional[str]): The revision/version of the model.
        api_key (Optional[str]): API key for accessing the model (if required).
        settings (TeapotAISettings): Configuration settings for the AI instance.
        tokenizer (AutoTokenizer): Tokenizer for the language model.
        quantized_model: The dynamically quantized seq2seq model used for generation.
        embedding_model (callable): The pipeline for feature extraction (document embeddings).
        documents (List[str]): List of documents for retrieval.
        document_embeddings (np.ndarray): Embeddings for the provided documents.
    """
    
    def __init__(self, model_revision: Optional[str] = None, api_key: Optional[str] = None,
                 documents: Optional[List[str]] = None, settings: TeapotAISettings = TeapotAISettings()):
        """
        Initializes the TeapotAI class with optional model_revision and api_key.

        Parameters:
            model_revision (Optional[str]): The revision/version of the model to use.
            api_key (Optional[str]): The API key for accessing the model if needed.
            documents (List[str]): A list of documents for retrieval. Defaults to an empty list.
            settings (TeapotAISettings): The settings configuration (defaults to TeapotAISettings()).
        """
        self.model = "teapotai/teapotllm"
        self.model_revision = model_revision
        self.api_key = api_key
        self.settings = settings
        
        if self.settings.verbose:
            print(""" _____                      _         _    ___        __o__    _;; 
|_   _|__  __ _ _ __   ___ | |_      / \  |_ _|   __ /-___-\__/ /
  | |/ _ \/ _` | '_ \ / _ \| __|    / _ \  | |   (  |       |__/
  | |  __/ (_| | |_) | (_) | |_    / ___ \ | |    \_|~~~~~~~|
  |_|\___|\__,_| .__/ \___/ \__/  /_/   \_\___|      \_____/
               |_|   """)
        
            print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}")
        
        # self.generator = pipeline("text2text-generation", model=self.model, revision=self.model_revision) if model_revision else pipeline("text2text-generation", model=self.model)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
        model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
        model.eval()

        # Dynamic quantization: swap Linear layers for int8 versions to cut memory use
        # and speed up CPU inference.
        self.quantized_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )

        self.documents = documents or []
        
        if self.settings.use_rag and self.documents:
            self.embedding_model = pipeline("feature-extraction", model="teapotai/teapotembedding")
            self.document_embeddings = self._generate_document_embeddings(self.documents)

    def _generate_document_embeddings(self, documents: List[str]) -> np.ndarray:
        """
        Generate embeddings for the provided documents using the embedding model.

        Parameters:
            documents (List[str]): A list of document strings to generate embeddings for.

        Returns:
            np.ndarray: A NumPy array of document embeddings.
        """
        embeddings = []

        if self.settings.verbose:
            print("Generating embeddings for documents...")

        # Show a progress bar only in verbose mode; the work is identical either way.
        iterator = tqdm(documents, desc="Document Embedding", unit="doc") if self.settings.verbose else documents
        for doc in iterator:
            # [0][0] keeps the first token's vector from the pipeline output as the document embedding
            embeddings.append(self.embedding_model(doc)[0][0])

        return np.array(embeddings)

    def rag(self, query: str) -> List[str]:
        """
        Perform RAG (Retrieval-Augmented Generation) by finding the most relevant documents based on cosine similarity.

        Parameters:
            query (str): The query string to find relevant documents for.

        Returns:
            List[str]: A list of the top N most relevant documents.
        """
        if not self.settings.use_rag or not self.documents:
            return []

        query_embedding = self.embedding_model(query)[0][0]
        similarities = cosine_similarity([query_embedding], self.document_embeddings)[0]

        filtered_indices = [i for i, similarity in enumerate(similarities) if similarity >= self.settings.rag_similarity_threshold]
        top_n_indices = sorted(filtered_indices, key=lambda i: similarities[i], reverse=True)[:self.settings.rag_num_results]

        return [self.documents[i] for i in top_n_indices]
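
    # Retrieval sketch (hypothetical document and query; values are assumptions):
    #   teapot = TeapotAI(documents=["Paris is the capital of France."])
    #   teapot.rag("What is the capital of France?")  # -> ["Paris is the capital of France."]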

    def generate(self, input_text: str) -> str:
        """
        Generate text based on the input string using the teapotllm model.

        Parameters:
            input_text (str): The text prompt to generate a response for.

        Returns:
            str: The generated output from the model.
        """
        
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

        with torch.no_grad():
            outputs = self.quantized_model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],  # pass the mask so padding tokens are ignored
                max_length=512,
            )

        result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        if self.settings.log_level == "debug":
            print(input_text)
            print(result)
        
        return result

    def query(self, query: str, context: str = "") -> str:
        """
        Handle a query and context, using RAG if no context is provided, and return a generated response.

        Parameters:
            query (str): The query string to be answered.
            context (str): The context to guide the response. Defaults to an empty string.

        Returns:
            str: The generated response based on the input query and context.
        """
        if self.settings.use_rag and not context:
            context = "\n".join(self.rag(query))  # Perform RAG if no context is provided
        
        input_text = f"Context: {context}\nQuery: {query}"
        return self.generate(input_text)
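
    # Usage sketch (hypothetical strings): pass context explicitly, or omit it and let RAG supply one.
    #   teapot.query("Who wrote Hamlet?", context="Hamlet is a tragedy written by William Shakespeare.")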

    def chat(self, conversation_history: List[dict]) -> str:
        """
        Engage in a chat by taking a list of previous messages and generating a response.

        Parameters:
            conversation_history (List[dict]): A list of previous messages, each containing 'content'.

        Returns:
            str: The generated response based on the conversation history.
        """
        chat_history = "".join([message['content'] + "\n" for message in conversation_history])

        if self.settings.use_rag:
            context_documents = self.rag(chat_history)  # Perform RAG on the conversation history
            context = "\n".join(context_documents)
            chat_history = f"Context: {context}\n" + chat_history

        return self.generate(chat_history + "\n" + "agent:")
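
    # Message-shape sketch (an assumption from the docstring; only 'content' is read, 'role' is ignored):
    #   teapot.chat([{"role": "user", "content": "What teas are high in caffeine?"}])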

    def extract(self, class_annotation: BaseModel, query: str = "", context: str = "") -> BaseModel:
        """
        Extract fields from a Pydantic class annotation by querying and processing each field.

        Parameters:
            class_annotation (BaseModel): The Pydantic class to extract fields from.
            query (str): The query string to guide the extraction. Defaults to an empty string.
            context (str): Optional context for the query.

        Returns:
            BaseModel: An instance of the provided Pydantic class with extracted field values.
        """
        if self.settings.use_rag:
            context_documents = self.rag(query)
            context = "\n".join(context_documents) + "\n" + context
        
        output = {}
        for field_name, field in class_annotation.__fields__.items():
            type_annotation = field.annotation
            description = field.description
            description_annotation = f"({description})" if description else ""

            result = self.query(f"Extract the field {field_name} {description_annotation} to a {type_annotation}", context=context)

            # Process result based on field type
            if type_annotation == bool:
                parsed_result = (
                    True if re.search(r'\b(yes|true)\b', result, re.IGNORECASE)
                    else (False if re.search(r'\b(no|false)\b', result, re.IGNORECASE) else None)
                )
            elif type_annotation in [int, float]:
                parsed_result = re.sub(r'[^0-9.]', '', result)
                if parsed_result:
                    try:
                        parsed_result = type_annotation(parsed_result)
                    except Exception:
                        parsed_result = None
                else:
                    parsed_result = None
            elif type_annotation == str:
                parsed_result = result.strip()
            else:
                raise ValueError(f"Unsupported type annotation: {type_annotation}")

            output[field_name] = parsed_result
        
        return class_annotation(**output)
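
    # Extraction sketch (hypothetical model and context):
    #   class TeaOrder(BaseModel):
    #       name: str
    #       price: float
    #       in_stock: bool
    #   teapot.extract(TeaOrder, context="Earl Grey costs $4.50 and is in stock.")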
        
##### End Library Code
def log_time(func):
    """Decorator that prints the wall-clock runtime of the wrapped function."""
    @wraps(func)  # keep func's name/docstring on the wrapper (useful under @traceable)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
        return result
    return wrapper

default_documents = []

API_KEY = os.environ.get("brave_api_key")

@log_time
def brave_search(query, count=3):
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}
    
    response = requests.get(url, headers=headers, params=params)
    
    if response.status_code == 200:
        results = response.json().get("web", {}).get("results", [])
        print(results)
        return [(res["title"], res["description"], res["url"]) for res in results]
    else:
        print(f"Error: {response.status_code}, {response.text}")
        return []
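
# brave_search returns a list of (title, description, url) tuples; a hypothetical result:
#   [("Tea - Wikipedia", "Tea is an aromatic beverage...", "https://en.wikipedia.org/wiki/Tea")]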

@traceable 
@log_time
def query_teapot(prompt, context, user_input, teapot_ai):
    response = teapot_ai.query(
        context=prompt+"\n"+context,
        query=user_input
    )
    return response

@log_time
def handle_chat(user_input, teapot_ai):
    results = brave_search(user_input)
    
    documents = [desc.replace('<strong>','').replace('</strong>','') for _, desc, _ in results]
    st.sidebar.write("---")
    st.sidebar.write("## RAG Documents")
    for (title, description, url) in results:
        # Display Results 
        st.sidebar.write(f"## {title}")
        st.sidebar.write(f"{description.replace('<strong>','').replace('</strong>','')}")
        st.sidebar.write(f"[Source]({url})")
        st.sidebar.write("---")

    context = "\n".join(documents)
    prompt = "You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."
    response = query_teapot(prompt, context, user_input, teapot_ai)
    
    return response

def suggestion_button(suggestion_text, teapot_ai):
    if st.button(suggestion_text):
        # Store the exchange in the chat history, then rerun so the transcript refreshes.
        st.session_state.messages.append({"role": "user", "content": suggestion_text})
        st.session_state.messages.append({"role": "assistant", "content": handle_chat(suggestion_text, teapot_ai)})
        st.rerun()

@log_time
def hash_documents(documents):
    """Hash the document list so the model and embeddings are only rebuilt when the documents change."""
    return hashlib.sha256("\n".join(documents).encode("utf-8")).hexdigest()

def main():
    st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide")
    
    st.sidebar.header("Retrieval Augmented Generation")
    user_documents = st.sidebar.text_area("Enter documents, each on a new line", value="\n".join(default_documents))
    
    documents = [doc.strip() for doc in user_documents.split("\n") if doc.strip()]
    new_documents_hash = hash_documents(documents)
    
    if "documents_hash" not in st.session_state or st.session_state.documents_hash != new_documents_hash:
        with st.spinner('Loading Model and Embeddings...'):
            start_time = time.time()
            teapot_ai = TeapotAI(documents=documents or default_documents, settings=TeapotAISettings(rag_num_results=3))
            end_time = time.time()
            print(f"Model loaded in {end_time - start_time:.4f} seconds")
        
        st.session_state.documents_hash = new_documents_hash
        st.session_state.teapot_ai = teapot_ai
    else:
        teapot_ai = st.session_state.teapot_ai
    
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}]
    
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    
    user_input = st.chat_input("Ask me anything")
    
    s1, s2, s3 = st.columns([1, 2, 3])
    with s1:
        suggestion_button("Tell me about the varieties of tea", teapot_ai)
    with s2:
        suggestion_button("Who was born first, Alan Turing or John von Neumann?", teapot_ai)
    with s3:
        suggestion_button("Extract Google's stock price", teapot_ai)
    
    if user_input:
        with st.chat_message("user"):
            st.markdown(user_input)
        
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.spinner('Generating Response...'):
            response = handle_chat(user_input, teapot_ai)
        
        with st.chat_message("assistant"):
            st.markdown(response)
        
        st.session_state.messages.append({"role": "assistant", "content": response})
        st.markdown("### Suggested Questions")

if __name__ == "__main__":
    main()