zakerytclarke commited on
Commit
ba84c03
·
verified ·
1 Parent(s): 492a988

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -230
app.py CHANGED
@@ -7,237 +7,7 @@ import time
7
  from langsmith import traceable
8
  import random
9
 
10
- ##### Begin Library Code
11
 
12
- from transformers import pipeline
13
- import torch
14
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
15
- import numpy as np
16
- from sklearn.metrics.pairwise import cosine_similarity
17
- from pydantic import BaseModel
18
- from typing import List, Optional
19
- from tqdm import tqdm
20
- import re
21
- import os
22
-
23
-
24
- class TeapotAISettings(BaseModel):
25
- """
26
- Pydantic settings model for TeapotAI configuration.
27
-
28
- Attributes:
29
- use_rag (bool): Whether to use RAG (Retrieve and Generate).
30
- rag_num_results (int): Number of top documents to retrieve based on similarity.
31
- rag_similarity_threshold (float): Similarity threshold for document relevance.
32
- verbose (bool): Whether to print verbose updates.
33
- log_level (str): The log level for the application (e.g., "info", "debug").
34
- """
35
- use_rag: bool = True # Whether to use RAG (Retrieve and Generate)
36
- rag_num_results: int = 3 # Number of top documents to retrieve based on similarity
37
- rag_similarity_threshold: float = 0.5 # Similarity threshold for document relevance
38
- verbose: bool = True # Whether to print verbose updates
39
- log_level: str = "info" # Log level setting (e.g., 'info', 'debug')
40
-
41
-
42
- class TeapotAI:
43
- """
44
- TeapotAI class that interacts with a language model for text generation and retrieval tasks.
45
-
46
- Attributes:
47
- model (str): The model identifier.
48
- model_revision (Optional[str]): The revision/version of the model.
49
- api_key (Optional[str]): API key for accessing the model (if required).
50
- settings (TeapotAISettings): Configuration settings for the AI instance.
51
- generator (callable): The pipeline for text generation.
52
- embedding_model (callable): The pipeline for feature extraction (document embeddings).
53
- documents (List[str]): List of documents for retrieval.
54
- document_embeddings (np.ndarray): Embeddings for the provided documents.
55
- """
56
-
57
- def __init__(self, model_revision: Optional[str] = None, api_key: Optional[str] = None,
58
- documents: List[str] = [], settings: TeapotAISettings = TeapotAISettings()):
59
- """
60
- Initializes the TeapotAI class with optional model_revision and api_key.
61
-
62
- Parameters:
63
- model_revision (Optional[str]): The revision/version of the model to use.
64
- api_key (Optional[str]): The API key for accessing the model if needed.
65
- documents (List[str]): A list of documents for retrieval. Defaults to an empty list.
66
- settings (TeapotAISettings): The settings configuration (defaults to TeapotAISettings()).
67
- """
68
- self.model = "teapotai/teapotllm"
69
- self.model_revision = model_revision
70
- self.api_key = api_key
71
- self.settings = settings
72
-
73
- if self.settings.verbose:
74
- print(""" _____ _ _ ___ __o__ _;;
75
- |_ _|__ __ _ _ __ ___ | |_ / \ |_ _| __ /-___-\__/ /
76
- | |/ _ \/ _` | '_ \ / _ \| __| / _ \ | | ( | |__/
77
- | | __/ (_| | |_) | (_) | |_ / ___ \ | | \_|~~~~~~~|
78
- |_|\___|\__,_| .__/ \___/ \__/ /_/ \_\___| \_____/
79
- |_| """)
80
-
81
- if self.settings.verbose:
82
- print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}")
83
-
84
- self.generator = pipeline("text2text-generation", model=self.model, revision=self.model_revision) if model_revision else pipeline("text2text-generation", model=self.model)
85
-
86
- self.documents = documents
87
-
88
- if self.settings.use_rag and self.documents:
89
- self.embedding_model = pipeline("feature-extraction", model="teapotai/teapotembedding")
90
- self.document_embeddings = self._generate_document_embeddings(self.documents)
91
-
92
- def _generate_document_embeddings(self, documents: List[str]) -> np.ndarray:
93
- """
94
- Generate embeddings for the provided documents using the embedding model.
95
-
96
- Parameters:
97
- documents (List[str]): A list of document strings to generate embeddings for.
98
-
99
- Returns:
100
- np.ndarray: A NumPy array of document embeddings.
101
- """
102
- embeddings = []
103
-
104
- if self.settings.verbose:
105
- print("Generating embeddings for documents...")
106
- for doc in tqdm(documents, desc="Document Embedding", unit="doc"):
107
- embeddings.append(self.embedding_model(doc)[0][0])
108
- else:
109
- for doc in documents:
110
- embeddings.append(self.embedding_model(doc)[0][0])
111
-
112
- return np.array(embeddings)
113
-
114
- def rag(self, query: str) -> List[str]:
115
- """
116
- Perform RAG (Retrieve and Generate) by finding the most relevant documents based on cosine similarity.
117
-
118
- Parameters:
119
- query (str): The query string to find relevant documents for.
120
-
121
- Returns:
122
- List[str]: A list of the top N most relevant documents.
123
- """
124
- if not self.settings.use_rag or not self.documents:
125
- return []
126
-
127
- query_embedding = self.embedding_model(query)[0][0]
128
- similarities = cosine_similarity([query_embedding], self.document_embeddings)[0]
129
-
130
- filtered_indices = [i for i, similarity in enumerate(similarities) if similarity >= self.settings.rag_similarity_threshold]
131
- top_n_indices = sorted(filtered_indices, key=lambda i: similarities[i], reverse=True)[:self.settings.rag_num_results]
132
-
133
- return [self.documents[i] for i in top_n_indices]
134
-
135
- def generate(self, input_text: str) -> str:
136
- """
137
- Generate text based on the input string using the teapotllm model.
138
-
139
- Parameters:
140
- input_text (str): The text prompt to generate a response for.
141
-
142
- Returns:
143
- str: The generated output from the model.
144
- """
145
-
146
-
147
- result = self.generator(input_text, max_length=512)[0].get("generated_text")
148
-
149
-
150
- if self.settings.log_level == "debug":
151
- print(input_text)
152
- print(result)
153
-
154
- return result
155
-
156
- def query(self, query: str, context: str = "") -> str:
157
- """
158
- Handle a query and context, using RAG if no context is provided, and return a generated response.
159
-
160
- Parameters:
161
- query (str): The query string to be answered.
162
- context (str): The context to guide the response. Defaults to an empty string.
163
-
164
- Returns:
165
- str: The generated response based on the input query and context.
166
- """
167
- if self.settings.use_rag and not context:
168
- context = "\n".join(self.rag(query)) # Perform RAG if no context is provided
169
-
170
- input_text = f"Context: {context}\nQuery: {query}"
171
- return self.generate(input_text)
172
-
173
- def chat(self, conversation_history: List[dict]) -> str:
174
- """
175
- Engage in a chat by taking a list of previous messages and generating a response.
176
-
177
- Parameters:
178
- conversation_history (List[dict]): A list of previous messages, each containing 'content'.
179
-
180
- Returns:
181
- str: The generated response based on the conversation history.
182
- """
183
- chat_history = "".join([message['content'] + "\n" for message in conversation_history])
184
-
185
- if self.settings.use_rag:
186
- context_documents = self.rag(chat_history) # Perform RAG on the conversation history
187
- context = "\n".join(context_documents)
188
- chat_history = f"Context: {context}\n" + chat_history
189
-
190
- return self.generate(chat_history + "\n" + "agent:")
191
-
192
- def extract(self, class_annotation: BaseModel, query: str = "", context: str = "") -> BaseModel:
193
- """
194
- Extract fields from a Pydantic class annotation by querying and processing each field.
195
-
196
- Parameters:
197
- class_annotation (BaseModel): The Pydantic class to extract fields from.
198
- query (str): The query string to guide the extraction. Defaults to an empty string.
199
- context (str): Optional context for the query.
200
-
201
- Returns:
202
- BaseModel: An instance of the provided Pydantic class with extracted field values.
203
- """
204
- if self.settings.use_rag:
205
- context_documents = self.rag(query)
206
- context = "\n".join(context_documents) + context
207
-
208
- output = {}
209
- for field_name, field in class_annotation.__fields__.items():
210
- type_annotation = field.annotation
211
- description = field.description
212
- description_annotation = f"({description})" if description else ""
213
-
214
- result = self.query(f"Extract the field {field_name} {description_annotation} to a {type_annotation}", context=context)
215
-
216
- # Process result based on field type
217
- if type_annotation == bool:
218
- parsed_result = (
219
- True if re.search(r'\b(yes|true)\b', result, re.IGNORECASE)
220
- else (False if re.search(r'\b(no|false)\b', result, re.IGNORECASE) else None)
221
- )
222
- elif type_annotation in [int, float]:
223
- parsed_result = re.sub(r'[^0-9.]', '', result)
224
- if parsed_result:
225
- try:
226
- parsed_result = type_annotation(parsed_result)
227
- except Exception:
228
- parsed_result = None
229
- else:
230
- parsed_result = None
231
- elif type_annotation == str:
232
- parsed_result = result.strip()
233
- else:
234
- raise ValueError(f"Unsupported type annotation: {type_annotation}")
235
-
236
- output[field_name] = parsed_result
237
-
238
- return class_annotation(**output)
239
-
240
- ##### End Library Code
241
  def log_time(func):
242
  def wrapper(*args, **kwargs):
243
  start_time = time.time()
 
7
  from langsmith import traceable
8
  import random
9
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def log_time(func):
12
  def wrapper(*args, **kwargs):
13
  start_time = time.time()