zakerytclarke committed
Commit 0c5711c · verified · Parent: ce7bb20

Update app.py

Files changed (1): app.py (+24 −208)
app.py CHANGED
@@ -1,5 +1,4 @@
 import streamlit as st
-from teapotai import TeapotAI, TeapotAISettings
 import hashlib
 import os
 import requests
@@ -19,211 +18,14 @@ from tqdm import tqdm
 import re
 import os
 
-class TeapotAISettings(BaseModel):
-    """
-    Pydantic settings model for TeapotAI configuration.
-
-    Attributes:
-        use_rag (bool): Whether to use RAG (Retrieve and Generate).
-        rag_num_results (int): Number of top documents to retrieve based on similarity.
-        rag_similarity_threshold (float): Similarity threshold for document relevance.
-        verbose (bool): Whether to print verbose updates.
-        log_level (str): The log level for the application (e.g., "info", "debug").
-    """
-    use_rag: bool = True  # Whether to use RAG (Retrieve and Generate)
-    rag_num_results: int = 3  # Number of top documents to retrieve based on similarity
-    rag_similarity_threshold: float = 0.5  # Similarity threshold for document relevance
-    verbose: bool = True  # Whether to print verbose updates
-    log_level: str = "info"  # Log level setting (e.g., 'info', 'debug')
-
-
-class TeapotAI:
-    """
-    TeapotAI class that interacts with a language model for text generation and retrieval tasks.
-
-    Attributes:
-        model (str): The model identifier.
-        model_revision (Optional[str]): The revision/version of the model.
-        api_key (Optional[str]): API key for accessing the model (if required).
-        settings (TeapotAISettings): Configuration settings for the AI instance.
-        generator (callable): The pipeline for text generation.
-        embedding_model (callable): The pipeline for feature extraction (document embeddings).
-        documents (List[str]): List of documents for retrieval.
-        document_embeddings (np.ndarray): Embeddings for the provided documents.
-    """
-
-    def __init__(self, model_revision: Optional[str] = None, api_key: Optional[str] = None,
-                 documents: List[str] = [], settings: TeapotAISettings = TeapotAISettings()):
-        """
-        Initializes the TeapotAI class with optional model_revision and api_key.
-        Parameters:
-            model_revision (Optional[str]): The revision/version of the model to use.
-            api_key (Optional[str]): The API key for accessing the model if needed.
-            documents (List[str]): A list of documents for retrieval. Defaults to an empty list.
-            settings (TeapotAISettings): The settings configuration (defaults to TeapotAISettings()).
-        """
-        self.model = "teapotai/teapotllm"
-        self.model_revision = model_revision
-        self.api_key = api_key
-        self.settings = settings
-
-        if self.settings.verbose:
-            print(""" _____                    _        _    ___    __o__  _;;
-|_   _|__  __ _ _ __  ___ | |_    / \  |_ _|   /-___-\__/ /
-  | |/ _ \/ _` | '_ \ / _ \| __|  / _ \  | |  ( |       |__/
-  | |  __/ (_| | |_) | (_) | |_  / ___ \ | |   \_|~~~~~~~|
-  |_|\___|\__,_| .__/ \___/ \__/ /_/   \_\___|   \_____/
-               |_| """)
-
-        if self.settings.verbose:
-            print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}")
-
-        self.generator = pipeline("text2text-generation", model=self.model, revision=self.model_revision) if model_revision else pipeline("text2text-generation", model=self.model)
-
-        self.documents = documents
-
-        if self.settings.use_rag and self.documents:
-            self.embedding_model = pipeline("feature-extraction", model="teapotai/teapotembedding")
-            self.document_embeddings = self._generate_document_embeddings(self.documents)
-
-    def _generate_document_embeddings(self, documents: List[str]) -> np.ndarray:
-        """
-        Generate embeddings for the provided documents using the embedding model.
-        Parameters:
-            documents (List[str]): A list of document strings to generate embeddings for.
-        Returns:
-            np.ndarray: A NumPy array of document embeddings.
-        """
-        embeddings = []
-
-        if self.settings.verbose:
-            print("Generating embeddings for documents...")
-            for doc in tqdm(documents, desc="Document Embedding", unit="doc"):
-                embeddings.append(self.embedding_model(doc)[0][0])
-        else:
-            for doc in documents:
-                embeddings.append(self.embedding_model(doc)[0][0])
-
-        return np.array(embeddings)
-
-    def rag(self, query: str) -> List[str]:
-        """
-        Perform RAG (Retrieve and Generate) by finding the most relevant documents based on cosine similarity.
-        Parameters:
-            query (str): The query string to find relevant documents for.
-        Returns:
-            List[str]: A list of the top N most relevant documents.
-        """
-        if not self.settings.use_rag or not self.documents:
-            return []
-
-        query_embedding = self.embedding_model(query)[0][0]
-        similarities = cosine_similarity([query_embedding], self.document_embeddings)[0]
-
-        filtered_indices = [i for i, similarity in enumerate(similarities) if similarity >= self.settings.rag_similarity_threshold]
-        top_n_indices = sorted(filtered_indices, key=lambda i: similarities[i], reverse=True)[:self.settings.rag_num_results]
-
-        return [self.documents[i] for i in top_n_indices]
-
-    def generate(self, input_text: str) -> str:
-        """
-        Generate text based on the input string using the teapotllm model.
-        Parameters:
-            input_text (str): The text prompt to generate a response for.
-        Returns:
-            str: The generated output from the model.
-        """
-        result = self.generator(input_text, max_length=512)[0].get("generated_text")
-
-        if self.settings.log_level == "debug":
-            print(input_text)
-            print(result)
-
-        return result
-
-    def query(self, query: str, context: str = "") -> str:
-        """
-        Handle a query and context, using RAG if no context is provided, and return a generated response.
-        Parameters:
-            query (str): The query string to be answered.
-            context (str): The context to guide the response. Defaults to an empty string.
-        Returns:
-            str: The generated response based on the input query and context.
-        """
-        if self.settings.use_rag and not context:
-            context = "\n".join(self.rag(query))  # Perform RAG if no context is provided
-
-        input_text = f"Context: {context}\nQuery: {query}"
-        return self.generate(input_text)
-
-    def chat(self, conversation_history: List[dict]) -> str:
-        """
-        Engage in a chat by taking a list of previous messages and generating a response.
-        Parameters:
-            conversation_history (List[dict]): A list of previous messages, each containing 'content'.
-        Returns:
-            str: The generated response based on the conversation history.
-        """
-        chat_history = "".join([message['content'] + "\n" for message in conversation_history])
-
-        if self.settings.use_rag:
-            context_documents = self.rag(chat_history)  # Perform RAG on the conversation history
-            context = "\n".join(context_documents)
-            chat_history = f"Context: {context}\n" + chat_history
-
-        return self.generate(chat_history + "\n" + "agent:")
-
-    def extract(self, class_annotation: BaseModel, query: str = "", context: str = "") -> BaseModel:
-        """
-        Extract fields from a Pydantic class annotation by querying and processing each field.
-        Parameters:
-            class_annotation (BaseModel): The Pydantic class to extract fields from.
-            query (str): The query string to guide the extraction. Defaults to an empty string.
-            context (str): Optional context for the query.
-        Returns:
-            BaseModel: An instance of the provided Pydantic class with extracted field values.
-        """
-        if self.settings.use_rag:
-            context_documents = self.rag(query)
-            context = "\n".join(context_documents) + context
-
-        output = {}
-        for field_name, field in class_annotation.__fields__.items():
-            type_annotation = field.annotation
-            description = field.description
-            description_annotation = f"({description})" if description else ""
-
-            result = self.query(f"Extract the field {field_name} {description_annotation} to a {type_annotation}", context=context)
-
-            # Process result based on field type
-            if type_annotation == bool:
-                parsed_result = (
-                    True if re.search(r'\b(yes|true)\b', result, re.IGNORECASE)
-                    else (False if re.search(r'\b(no|false)\b', result, re.IGNORECASE) else None)
-                )
-            elif type_annotation in [int, float]:
-                parsed_result = re.sub(r'[^0-9.]', '', result)
-                if parsed_result:
-                    try:
-                        parsed_result = type_annotation(parsed_result)
-                    except Exception:
-                        parsed_result = None
-                else:
-                    parsed_result = None
-            elif type_annotation == str:
-                parsed_result = result.strip()
-            else:
-                raise ValueError(f"Unsupported type annotation: {type_annotation}")
-
-            output[field_name] = parsed_result
-
-        return class_annotation(**output)
-
-### End Library Code
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+with st.spinner('Loading Model...'):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 
 def log_time(func):
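Note: the hunk above deletes the inlined TeapotAI library code, including its retrieval step. For reference, the deleted rag() method amounts to cosine-similarity retrieval with a score threshold and a top-N cut. A minimal standalone sketch of that logic follows; the retrieve() name and signature are illustrative, not part of the deleted class, while the 0.5 threshold and top-3 defaults come from TeapotAISettings:

from sklearn.metrics.pairwise import cosine_similarity

def retrieve(query_embedding, document_embeddings, documents,
             similarity_threshold=0.5, num_results=3):
    # Cosine similarity between the query embedding and every document embedding.
    similarities = cosine_similarity([query_embedding], document_embeddings)[0]
    # Keep documents at or above the threshold, then take the top N by similarity.
    kept = [i for i, s in enumerate(similarities) if s >= similarity_threshold]
    top = sorted(kept, key=lambda i: similarities[i], reverse=True)[:num_results]
    return [documents[i] for i in top]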
@@ -235,7 +37,6 @@ def log_time(func):
         return result
     return wrapper
 
-default_documents = []
 
 API_KEY = os.environ.get("brave_api_key")
 
@@ -258,11 +59,26 @@ def brave_search(query, count=3):
 @traceable
 @log_time
 def query_teapot(prompt, context, user_input, teapot_ai):
-    response = teapot_ai.query(
-        context=prompt+"\n"+context,
-        query=user_input
-    )
-    return response
+    input_text = prompt + "\n" + context + "\n" + user_input
+
+    start_time = time.time()
+
+    inputs = tokenizer(input_text, return_tensors="pt")
+    input_length = inputs["input_ids"].shape[1]
+
+    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
+
+    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    total_length = output.shape[1]  # Includes both input and output tokens
+    output_length = total_length - input_length  # Extract output token count
+
+    end_time = time.time()
+
+    elapsed_time = end_time - start_time
+    tokens_per_second = total_length / elapsed_time if elapsed_time > 0 else float("inf")
+
+    return f"{output_text} ({tokens_per_second} tokens per second)"
 
 
 @log_time
 def handle_chat(user_prompt, user_input, teapot_ai):
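Note: the new query_teapot() body references model_name and max_new_tokens, neither of which is defined in these hunks, so app.py must set them elsewhere. Also, because the model is loaded with AutoModelForSeq2SeqLM, generate() returns only decoder tokens; output.shape[1] is therefore the generated length rather than input plus output as the inline comment suggests, and the reported tokens-per-second figure inherits that caveat. A minimal standalone sketch of the new generation path, assuming model_name = "teapotai/teapotllm" (the id the deleted class used) and max_new_tokens = 512 (an assumed value):

import time

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "teapotai/teapotllm"  # assumption: defined elsewhere in app.py
max_new_tokens = 512               # assumption: defined elsewhere in app.py

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def generate_with_stats(prompt, context, user_input):
    # Concatenate the pieces the same way the new query_teapot() does.
    input_text = prompt + "\n" + context + "\n" + user_input

    start_time = time.time()
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    elapsed = time.time() - start_time

    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Seq2seq generate() returns decoder tokens only, so this counts generated tokens.
    generated_tokens = output.shape[1]
    tokens_per_second = generated_tokens / elapsed if elapsed > 0 else float("inf")
    return f"{output_text} ({tokens_per_second:.1f} tokens per second)"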