Update main.py
Browse files
main.py
CHANGED
@@ -46,6 +46,17 @@ class UnifiedModel(nn.Module):
|
|
46 |
logits = self.classifier(concatenated_hidden_states)
|
47 |
return logits
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
class SyntheticDataset(Dataset):
|
50 |
def __init__(self, tokenizers, data):
|
51 |
self.tokenizers = tokenizers
|
@@ -68,6 +79,20 @@ class SyntheticDataset(Dataset):
|
|
68 |
|
69 |
@app.post("/process")
|
70 |
async def process(request: Request):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
data = await request.json()
|
72 |
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
|
73 |
|
@@ -102,41 +127,13 @@ async def process(request: Request):
|
|
102 |
if not user_data:
|
103 |
user_data = [{"text": "Sample text for automatic training.", "label": 0}]
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
evaluation_strategy="epoch",
|
111 |
-
learning_rate=5e-5,
|
112 |
-
per_device_train_batch_size=8,
|
113 |
-
per_device_eval_batch_size=8,
|
114 |
-
num_train_epochs=10,
|
115 |
-
weight_decay=0.01,
|
116 |
-
logging_steps=10,
|
117 |
-
optim="adamw_hf"
|
118 |
-
)
|
119 |
-
|
120 |
-
optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
|
121 |
-
unified_model.train()
|
122 |
-
|
123 |
-
for epoch in range(training_args.num_train_epochs):
|
124 |
-
for batch in train_loader:
|
125 |
-
input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in tokenizers.keys()]
|
126 |
-
attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in tokenizers.keys()]
|
127 |
-
labels = batch["labels"].to("cpu")
|
128 |
-
outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
|
129 |
-
loss = nn.CrossEntropyLoss()(outputs, labels)
|
130 |
-
loss.backward()
|
131 |
-
optimizer.step()
|
132 |
-
optimizer.zero_grad()
|
133 |
-
|
134 |
-
print(f"Epoch {epoch}, Loss {loss.item()}")
|
135 |
-
|
136 |
-
print("Training complete.")
|
137 |
|
138 |
-
|
139 |
-
return {"message": "Model trained and updated in Redis."}
|
140 |
|
141 |
elif data.get("predict"):
|
142 |
text = data['text']
|
@@ -155,6 +152,19 @@ async def process(request: Request):
|
|
155 |
|
156 |
@app.post("/external_answer")
|
157 |
async def external_answer(request: Request):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
data = await request.json()
|
159 |
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
|
160 |
|
@@ -162,26 +172,16 @@ async def external_answer(request: Request):
|
|
162 |
if not question:
|
163 |
raise HTTPException(status_code=400, detail="Question is required.")
|
164 |
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
model_data_bytes = redis_client.get(f"model:{model_name}")
|
169 |
-
tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
|
170 |
-
|
171 |
-
if model_data_bytes:
|
172 |
-
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
|
173 |
-
model.load_state_dict(torch.load(model_data_bytes))
|
174 |
-
else:
|
175 |
-
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
|
176 |
|
|
|
177 |
if tokenizer_data_bytes:
|
178 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
179 |
tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
|
180 |
else:
|
181 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
182 |
-
|
183 |
-
unified_model = UnifiedModel([model])
|
184 |
-
unified_model.to(torch.device("cpu"))
|
185 |
|
186 |
tokenized_input = tokenizer(question, return_tensors="pt")
|
187 |
input_ids = tokenized_input['input_ids']
|
@@ -192,45 +192,22 @@ async def external_answer(request: Request):
|
|
192 |
predicted_class = torch.argmax(logits, dim=-1).item()
|
193 |
response = {"answer": f"Response to '{question}' is class {predicted_class}"}
|
194 |
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
output_dir="memory",
|
201 |
-
evaluation_strategy="epoch",
|
202 |
-
learning_rate=5e-5,
|
203 |
-
per_device_train_batch_size=8,
|
204 |
-
per_device_eval_batch_size=8,
|
205 |
-
num_train_epochs=10,
|
206 |
-
weight_decay=0.01,
|
207 |
-
logging_steps=10,
|
208 |
-
optim="adamw_hf"
|
209 |
-
)
|
210 |
-
|
211 |
-
optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
|
212 |
-
unified_model.train()
|
213 |
-
|
214 |
-
for epoch in range(training_args.num_train_epochs):
|
215 |
-
for batch in train_loader:
|
216 |
-
input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]]
|
217 |
-
attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]]
|
218 |
-
labels = batch["labels"].to("cpu")
|
219 |
-
outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
|
220 |
-
loss = nn.CrossEntropyLoss()(outputs, labels)
|
221 |
-
loss.backward()
|
222 |
-
optimizer.step()
|
223 |
-
optimizer.zero_grad()
|
224 |
-
|
225 |
-
print(f"Epoch {epoch}, Loss {loss.item()}")
|
226 |
-
|
227 |
-
print("Extreme training complete.")
|
228 |
-
push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)
|
229 |
|
230 |
return response
|
231 |
|
232 |
@app.get("/")
|
233 |
async def get_home():
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
html_code = """
|
235 |
<!DOCTYPE html>
|
236 |
<html>
|
@@ -264,6 +241,18 @@ async def get_home():
|
|
264 |
return HTMLResponse(content=html_code)
|
265 |
|
266 |
def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
for model_name, model in models.items():
|
268 |
torch.save(model.state_dict(), model_name)
|
269 |
redis_client.set(f"model:{model_name}", open(model_name, "rb").read())
|
@@ -273,11 +262,14 @@ def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
|
|
273 |
redis_client.set(f"tokenizer:{tokenizer_name}", json.dumps(tokens))
|
274 |
|
275 |
def continuous_training():
|
|
|
|
|
|
|
276 |
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
|
277 |
|
278 |
while True:
|
279 |
try:
|
280 |
-
data = redis_client.
|
281 |
if data:
|
282 |
data = json.loads(data)
|
283 |
unified_model = UnifiedModel.load_model_from_redis(redis_client)
|
@@ -302,13 +294,16 @@ def continuous_training():
|
|
302 |
print(f"Epoch {epoch}, Loss {loss.item()}")
|
303 |
|
304 |
push_to_redis(unified_model.models, data["tokenizers"], redis_client, "unified_model", "unified_tokenizer")
|
305 |
-
redis_client.delete("training_queue")
|
306 |
time.sleep(10)
|
307 |
except Exception as e:
|
308 |
print(f"Error in continuous training: {e}")
|
309 |
time.sleep(5)
|
310 |
|
311 |
if __name__ == "__main__":
|
312 |
-
|
|
|
|
|
|
|
|
|
313 |
import uvicorn
|
314 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
46 |
logits = self.classifier(concatenated_hidden_states)
|
47 |
return logits
|
48 |
|
49 |
+
@staticmethod
def load_model_from_redis(redis_client):
    """Build a UnifiedModel from the weights stored in Redis, if any.

    Reads the key ``model:unified_model``. When a payload is present it is
    deserialized with ``torch.load`` and applied to a freshly created GPT-2
    sequence-classification model; otherwise the pretrained GPT-2 weights
    are used unchanged.

    Args:
        redis_client: Redis client used to fetch the serialized state dict.

    Returns:
        UnifiedModel: A unified model wrapping the single loaded model.
    """
    import io

    model_name = "unified_model"
    model_data_bytes = redis_client.get(f"model:{model_name}")

    # Both paths use the same two-label head so a stored state dict always
    # matches the architecture it is loaded into.
    model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)

    if model_data_bytes:
        # NOTE(review): the payload is binary; a client created with
        # decode_responses=True will try to decode it as text. Confirm that
        # callers pass a client that returns raw bytes for this key.
        if isinstance(model_data_bytes, (bytes, bytearray, memoryview)):
            # torch.load expects a path or a file-like object, not raw bytes.
            source = io.BytesIO(bytes(model_data_bytes))
        else:
            source = model_data_bytes
        # NOTE: torch.load unpickles its input; only load payloads written by
        # this application (consider weights_only=True where available).
        state_dict = torch.load(source, map_location="cpu")
        model.load_state_dict(state_dict)

    return UnifiedModel([model])
|
59 |
+
|
60 |
class SyntheticDataset(Dataset):
|
61 |
def __init__(self, tokenizers, data):
|
62 |
self.tokenizers = tokenizers
|
|
|
79 |
|
80 |
@app.post("/process")
|
81 |
async def process(request: Request):
|
82 |
+
"""
|
83 |
+
Processes requests for training and prediction.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
request (Request): The incoming request object.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
dict: A dictionary containing either a message indicating successful
|
90 |
+
training data submission or the model's prediction.
|
91 |
+
|
92 |
+
Raises:
|
93 |
+
HTTPException: If the request does not contain 'train' or 'predict'
|
94 |
+
keys.
|
95 |
+
"""
|
96 |
data = await request.json()
|
97 |
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
|
98 |
|
|
|
127 |
if not user_data:
|
128 |
user_data = [{"text": "Sample text for automatic training.", "label": 0}]
|
129 |
|
130 |
+
# Add user data to Redis queue for asynchronous training
|
131 |
+
redis_client.rpush("training_queue", json.dumps({
|
132 |
+
"tokenizers": {tokenizer_name: tokenizer.get_vocab()},
|
133 |
+
"data": user_data
|
134 |
+
}))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
+
return {"message": "Training data received. Model will be updated asynchronously."}
|
|
|
137 |
|
138 |
elif data.get("predict"):
|
139 |
text = data['text']
|
|
|
152 |
|
153 |
@app.post("/external_answer")
|
154 |
async def external_answer(request: Request):
|
155 |
+
"""
|
156 |
+
Provides an answer to a question using the unified model and triggers
|
157 |
+
asynchronous training with the new question-answer pair.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
request (Request): The incoming request object containing the question.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
dict: A dictionary containing the answer to the question.
|
164 |
+
|
165 |
+
Raises:
|
166 |
+
HTTPException: If the request does not contain a 'question' key.
|
167 |
+
"""
|
168 |
data = await request.json()
|
169 |
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
|
170 |
|
|
|
172 |
if not question:
|
173 |
raise HTTPException(status_code=400, detail="Question is required.")
|
174 |
|
175 |
+
# Load the model and tokenizer from Redis
|
176 |
+
unified_model = UnifiedModel.load_model_from_redis(redis_client)
|
177 |
+
unified_model.to(torch.device("cpu"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
tokenizer_data_bytes = redis_client.get(f"tokenizer:unified_tokenizer")
|
180 |
if tokenizer_data_bytes:
|
181 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
182 |
tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
|
183 |
else:
|
184 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
|
|
|
|
|
|
185 |
|
186 |
tokenized_input = tokenizer(question, return_tensors="pt")
|
187 |
input_ids = tokenized_input['input_ids']
|
|
|
192 |
predicted_class = torch.argmax(logits, dim=-1).item()
|
193 |
response = {"answer": f"Response to '{question}' is class {predicted_class}"}
|
194 |
|
195 |
+
# Asynchronously train on the new data point
|
196 |
+
redis_client.rpush("training_queue", json.dumps({
|
197 |
+
"tokenizers": {"unified_tokenizer": tokenizer.get_vocab()},
|
198 |
+
"data": [{"text": question, "label": predicted_class}]
|
199 |
+
}))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
return response
|
202 |
|
203 |
@app.get("/")
|
204 |
async def get_home():
|
205 |
+
"""
|
206 |
+
Serves a basic HTML page as the home route.
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
HTMLResponse: The HTML content of the home page.
|
210 |
+
"""
|
211 |
html_code = """
|
212 |
<!DOCTYPE html>
|
213 |
<html>
|
|
|
241 |
return HTMLResponse(content=html_code)
|
242 |
|
243 |
def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
|
244 |
+
"""
|
245 |
+
Saves the given models and tokenizers to Redis.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
models (dict): A dictionary of model names and their corresponding
|
249 |
+
PyTorch models.
|
250 |
+
tokenizers (dict): A dictionary of tokenizer names and their
|
251 |
+
corresponding tokenizers.
|
252 |
+
redis_client: The Redis client instance.
|
253 |
+
model_name (str): The base name to use for saving the models.
|
254 |
+
tokenizer_name (str): The base name to use for saving the tokenizers.
|
255 |
+
"""
|
256 |
for model_name, model in models.items():
|
257 |
torch.save(model.state_dict(), model_name)
|
258 |
redis_client.set(f"model:{model_name}", open(model_name, "rb").read())
|
|
|
262 |
redis_client.set(f"tokenizer:{tokenizer_name}", json.dumps(tokens))
|
263 |
|
264 |
def continuous_training():
|
265 |
+
"""
|
266 |
+
Continuously checks for new training data in Redis and updates the model.
|
267 |
+
"""
|
268 |
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
|
269 |
|
270 |
while True:
|
271 |
try:
|
272 |
+
data = redis_client.lpop("training_queue")
|
273 |
if data:
|
274 |
data = json.loads(data)
|
275 |
unified_model = UnifiedModel.load_model_from_redis(redis_client)
|
|
|
294 |
print(f"Epoch {epoch}, Loss {loss.item()}")
|
295 |
|
296 |
push_to_redis(unified_model.models, data["tokenizers"], redis_client, "unified_model", "unified_tokenizer")
|
|
|
297 |
time.sleep(10)
|
298 |
except Exception as e:
|
299 |
print(f"Error in continuous training: {e}")
|
300 |
time.sleep(5)
|
301 |
|
302 |
if __name__ == "__main__":
    import uvicorn

    # Start the continuous training loop in a separate process so training
    # never blocks request handling. daemon=True is required because the
    # worker loops forever: multiprocessing joins non-daemonic children at
    # interpreter exit, which would hang the program once the server stops.
    training_process = multiprocessing.Process(
        target=continuous_training,
        daemon=True,
    )
    training_process.start()

    try:
        # Run the FastAPI app; this call blocks until the server shuts down.
        uvicorn.run(app, host="0.0.0.0", port=7860)
    finally:
        # Stop the worker explicitly instead of relying on daemon cleanup.
        training_process.terminate()
        training_process.join()
|