Yjhhh committed on
Commit 0f5190d · verified · 1 Parent(s): 0f03a5d

Create App.py

Files changed (1)
  1. App.py +428 -0
App.py ADDED
@@ -0,0 +1,428 @@
from dotenv import load_dotenv
import os
import io
import json
import base64
import redis
from transformers import (
    AutoTokenizer,
    AutoModel,
    TrainingArguments,
)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
import multiprocessing
import time

load_dotenv()

# Redis connection settings are read from the environment (.env is loaded above).
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT', 6379))  # redis-py expects an int port
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

app = FastAPI()

class UnifiedModel(nn.Module):
    """Classifier head on top of one or more transformer encoders."""

    def __init__(self, models):
        super(UnifiedModel, self).__init__()
        self.models = nn.ModuleList(models)
        # One linear layer over the concatenated pooled states of every wrapped model.
        self.classifier = nn.Linear(sum(model.config.hidden_size for model in models), 2)

    def forward(self, input_ids, attention_mask):
        # `input_ids` and `attention_mask` are lists with one tensor per wrapped model.
        hidden_states = []
        for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
            outputs = model(
                input_ids=input_id,
                attention_mask=attn_mask
            )
            # Use the first token's hidden state as a pooled sentence representation.
            hidden_states.append(outputs.last_hidden_state[:, 0, :])
        concatenated_hidden_states = torch.cat(hidden_states, dim=-1)
        logits = self.classifier(concatenated_hidden_states)
        return logits

class SyntheticDataset(Dataset):
    """Wraps a list of {"text", "label"} dicts and tokenizes with every registered tokenizer."""

    def __init__(self, tokenizers, data):
        self.tokenizers = tokenizers
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        label = item['label']
        tokenized = {}
        for name, tokenizer in self.tokenizers.items():
            tokens = tokenizer(text, padding="max_length", truncation=True, max_length=128)
            tokenized[f"input_ids_{name}"] = torch.tensor(tokens["input_ids"])
            tokenized[f"attention_mask_{name}"] = torch.tensor(tokens["attention_mask"])
        tokenized["label"] = torch.tensor(label)
        return tokenized

@app.post("/process")
async def process(request: Request):
    data = await request.json()
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)

    tokenizers = {}
    models = {}

    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")

    # Restore weights from Redis when available, otherwise start from the base checkpoint.
    # Weights are stored as base64-encoded torch.save blobs (see push_to_redis).
    model = AutoModel.from_pretrained("gpt2")
    if model_data_bytes:
        state_dict = torch.load(io.BytesIO(base64.b64decode(model_data_bytes)), map_location="cpu")
        model.load_state_dict(state_dict)
    models[model_name] = model

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # GPT-2 ships without a padding token; reuse EOS so padding="max_length" works.
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer_data_bytes:
        tokenizer_data = json.loads(tokenizer_data_bytes)
        tokenizer.add_tokens(list(tokenizer_data.keys()))
    tokenizers[tokenizer_name] = tokenizer

    unified_model = UnifiedModel(list(models.values()))
    unified_model.to(torch.device("cpu"))

    if data.get("train"):
        user_data = data.get("user_data", [])
        if not user_data:
            user_data = [{"text": "Sample text for automatic training.", "label": 0}]

        train_dataset = SyntheticDataset(tokenizers, user_data)
        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

        # TrainingArguments is used only as a hyperparameter container here;
        # the loop below performs the actual optimization.
        training_args = TrainingArguments(
            output_dir="memory",
            evaluation_strategy="epoch",
            learning_rate=5e-5,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            num_train_epochs=10,
            weight_decay=0.01,
            logging_steps=10,
            optim="adamw_hf"
        )

        optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
        unified_model.train()

        for epoch in range(int(training_args.num_train_epochs)):
            for batch in train_loader:
                input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in tokenizers.keys()]
                attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in tokenizers.keys()]
                labels = batch["label"].to("cpu")
                outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
                loss = nn.CrossEntropyLoss()(outputs, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            print(f"Epoch {epoch}, Loss {loss.item()}")

        print("Training complete.")

        push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name)
        return {"message": "Model trained and updated in Redis."}

    elif data.get("predict"):
        text = data['text']
        tokenized_inputs = [tokenizers[name](text, return_tensors="pt") for name in tokenizers.keys()]
        input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
        attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]

        with torch.no_grad():
            logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            predicted_class = torch.argmax(logits, dim=-1).item()

        return {"prediction": predicted_class}

    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'predict'.")

@app.post("/external_answer")
async def external_answer(request: Request):
    data = await request.json()
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)

    question = data.get('question')
    if not question:
        raise HTTPException(status_code=400, detail="Question is required.")

    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")

    model = AutoModel.from_pretrained("gpt2")
    if model_data_bytes:
        state_dict = torch.load(io.BytesIO(base64.b64decode(model_data_bytes)), map_location="cpu")
        model.load_state_dict(state_dict)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer_data_bytes:
        tokenizer_data = json.loads(tokenizer_data_bytes)
        tokenizer.add_tokens(list(tokenizer_data.keys()))

    unified_model = UnifiedModel([model])
    unified_model.to(torch.device("cpu"))

    # UnifiedModel.forward expects one tensor per wrapped model, hence the single-element lists.
    tokenized_input = tokenizer(question, return_tensors="pt")
    input_ids = [tokenized_input['input_ids']]
    attention_mask = [tokenized_input['attention_mask']]

    with torch.no_grad():
        logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
        predicted_class = torch.argmax(logits, dim=-1).item()
        response = {"answer": f"Response to '{question}' is class {predicted_class}"}

    # Immediately fine-tune on the question, using the model's own prediction as the label.
    extreme_training_data = [{"text": question, "label": predicted_class}]
    train_dataset = SyntheticDataset({tokenizer_name: tokenizer}, extreme_training_data)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

    training_args = TrainingArguments(
        output_dir="memory",
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=10,
        weight_decay=0.01,
        logging_steps=10,
        optim="adamw_hf"
    )

    optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
    unified_model.train()

    for epoch in range(int(training_args.num_train_epochs)):
        for batch in train_loader:
            input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]]
            attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]]
            labels = batch["label"].to("cpu")
            outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print(f"Epoch {epoch}, Loss {loss.item()}")

    print("Extreme training complete.")
    push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)

    return response

@app.get("/")
async def get_home():
    html_code = """
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Chatbot</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                background-color: #f4f4f9;
                margin: 0;
                padding: 0;
            }
            .container {
                max-width: 1200px;
                margin: 0 auto;
                padding: 20px;
            }
            h1 {
                color: #333;
                text-align: center;
            }
            .grid-container {
                display: grid;
                grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
                gap: 10px;
                margin-top: 20px;
            }
            .grid-item {
                background: #fff;
                padding: 20px;
                border-radius: 8px;
                box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
                transition: transform 0.3s;
            }
            .grid-item:hover {
                transform: scale(1.05);
            }
            .question {
                font-weight: bold;
                color: #007bff;
            }
            .answer {
                margin-top: 10px;
                color: #333;
            }
            input[type="text"] {
                width: calc(100% - 22px);
                padding: 10px;
                margin: 0;
                border: 1px solid #ddd;
                border-radius: 4px;
            }
            button {
                padding: 10px 20px;
                background-color: #007bff;
                color: #fff;
                border: none;
                border-radius: 4px;
                cursor: pointer;
                margin-top: 10px;
            }
            button:hover {
                background-color: #0056b3;
            }
        </style>
        <script>
            async function sendMessage() {
                const question = document.getElementById('question').value;
                const responseElement = document.getElementById('response');

                const response = await fetch('/external_answer', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ question: question })
                });

                const data = await response.json();
                responseElement.innerText = "Response: " + data.answer;

                const gridContainer = document.getElementById('grid-container');
                const newItem = document.createElement('div');
                newItem.classList.add('grid-item');
                newItem.innerHTML = `<div class="question">${question}</div><div class="answer">${data.answer}</div>`;
                gridContainer.prepend(newItem);
            }
        </script>
    </head>
    <body>
        <div class="container">
            <h1>Chatbot</h1>
            <input type="text" id="question" placeholder="Ask me something...">
            <button onclick="sendMessage()">Send</button>
            <div id="response"></div>
            <div class="grid-container" id="grid-container"></div>
        </div>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code)

def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
    # Torch state dicts are not JSON-serialisable, so store the weights as a
    # base64-encoded torch.save blob (the endpoints above decode it symmetrically).
    buffer = io.BytesIO()
    torch.save(next(iter(models.values())).state_dict(), buffer)
    redis_client.set(f"model:{model_name}", base64.b64encode(buffer.getvalue()).decode("ascii"))

    # The tokenizer vocabulary is a plain str -> int mapping, so JSON works here.
    tokenizer_data = json.dumps(next(iter(tokenizers.values())).get_vocab())
    redis_client.set(f"tokenizer:{tokenizer_name}", tokenizer_data)

def continuous_training():
    """Background loop: periodically reload the model from Redis and fine-tune it."""
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)

    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    while True:
        try:
            model_data_bytes = redis_client.get(f"model:{model_name}")
            tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")

            if model_data_bytes and tokenizer_data_bytes:
                model = AutoModel.from_pretrained("gpt2")
                state_dict = torch.load(io.BytesIO(base64.b64decode(model_data_bytes)), map_location="cpu")
                model.load_state_dict(state_dict)

                tokenizer = AutoTokenizer.from_pretrained("gpt2")
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                tokenizer_data = json.loads(tokenizer_data_bytes)
                tokenizer.add_tokens(list(tokenizer_data.keys()))

                unified_model = UnifiedModel([model])
                unified_model.to(torch.device("cpu"))

                train_data = [{"text": "Sample training text.", "label": 0}]
                train_dataset = SyntheticDataset({tokenizer_name: tokenizer}, train_data)
                train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

                training_args = TrainingArguments(
                    output_dir="memory",
                    evaluation_strategy="epoch",
                    learning_rate=5e-5,
                    per_device_train_batch_size=8,
                    per_device_eval_batch_size=8,
                    num_train_epochs=10,
                    weight_decay=0.01,
                    logging_steps=10,
                    optim="adamw_hf"
                )

                optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
                unified_model.train()

                for epoch in range(int(training_args.num_train_epochs)):
                    for batch in train_loader:
                        input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]]
                        attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]]
                        labels = batch["label"].to("cpu")
                        outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
                        loss = nn.CrossEntropyLoss()(outputs, labels)
                        loss.backward()
                        optimizer.step()
                        optimizer.zero_grad()

                    print(f"Epoch {epoch}, Loss {loss.item()}")

                print("Training complete.")
                push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)
            else:
                print("No model or tokenizer found in Redis. Skipping training.")

            time.sleep(600)

        except Exception as e:
            print(f"An error occurred: {e}")
            time.sleep(60)

def start_server():
    import uvicorn
    cpu_cores = os.cpu_count() or 1
    num_workers = max(1, cpu_cores - 1)

    # uvicorn only honours `workers` when the application is passed as an import
    # string, so reference this module ("App:app") rather than the app object.
    uvicorn.run("App:app", host="0.0.0.0", port=8000, workers=num_workers, timeout_keep_alive=0)

if __name__ == "__main__":
    # Run the API server and the background training loop in separate processes.
    api_process = multiprocessing.Process(target=start_server)
    training_process = multiprocessing.Process(target=continuous_training)

    api_process.start()
    training_process.start()

    api_process.join()
    training_process.join()