Yjhhh committed on
Commit
357dacf
·
verified ·
1 Parent(s): e2d0b71

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +378 -0
main.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import io
import json
import multiprocessing
import os
import time
import uuid

import redis
import requests
import torch
import torch.nn as nn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
)
20
+
21
# Load variables from a local .env file into the process environment.
load_dotenv()

# Redis connection settings; os.getenv returns str (or None when unset).
# NOTE(review): REDIS_PORT is kept as a string here — redis-py's port argument
# is typically an int; confirm the environment value is accepted downstream.
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = os.getenv('REDIS_PORT')
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

app = FastAPI()

# Fallback language for chat responses when a request does not specify one.
default_language = "es"
30
+
31
class ChatbotService:
    """Chat service backed by a causal LM whose weights/tokenizer live in Redis.

    Uses two Redis clients: a text client (decode_responses=True) for JSON/text
    values, and a binary client for raw torch state_dict bytes, which must not
    be UTF-8 decoded.
    """

    def __init__(self):
        # Text client: tokenizer vocab is stored as JSON text.
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
        # BUG FIX: model weights are raw bytes; fetching them through a
        # decode_responses=True client corrupts them before torch.load.
        self.binary_redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=False)
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"
        self.model = self.load_model_from_redis()
        self.tokenizer = self.load_tokenizer_from_redis()

    def get_response(self, user_id, message, language=default_language):
        """Generate an assistant reply for *message*.

        user_id and language are currently unused by the generation step but
        are kept for interface stability with existing callers.
        Returns a Spanish "not ready" notice when model/tokenizer are missing.
        """
        if self.model is None or self.tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."

        input_text = f"Usuario: {message} Asistente:"
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cpu")

        with torch.no_grad():
            output = self.model.generate(input_ids=input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)

        response = self.tokenizer.decode(output[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the assistant's continuation remains.
        response = response.replace(input_text, "").strip()

        return response

    def load_model_from_redis(self):
        """Load fine-tuned weights onto a base GPT-2 skeleton, or return None."""
        model_data_bytes = self.binary_redis_client.get(f"model:{self.model_name}")
        if model_data_bytes:
            model = AutoModelForCausalLM.from_pretrained("gpt2")
            # BUG FIX: torch.load expects a path or file-like object, not a
            # raw value; wrap the fetched bytes in a BytesIO buffer.
            model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
            return model
        return None

    def load_tokenizer_from_redis(self):
        """Load base GPT-2 tokenizer extended with vocab from Redis, or None."""
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if tokenizer_data_bytes:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            # add_tokens expects a list of token strings; the queue elsewhere
            # stores get_vocab() (a dict), so flatten a dict payload to its keys.
            extra_tokens = json.loads(tokenizer_data_bytes)
            if isinstance(extra_tokens, dict):
                extra_tokens = list(extra_tokens.keys())
            tokenizer.add_tokens(extra_tokens)
            return tokenizer
        return None
69
+
70
# Module-level singleton shared by the HTTP handlers; connects to Redis and
# loads the response model once at import time.
chatbot_service = ChatbotService()
71
+
72
class UnifiedModel(nn.Module):
    """Ensemble classifier over one or more sequence-classification models.

    Each wrapped model is expected to emit 3 logits per example (num_labels=3).
    The per-model logits are concatenated, projected back to the backbone's
    hidden size, and mapped to 3 output classes.
    """

    def __init__(self, models):
        super(UnifiedModel, self).__init__()
        self.models = nn.ModuleList(models)
        hidden_size = self.models[0].config.hidden_size
        # CONSISTENCY FIX: the projection output was hard-coded to 768 while
        # the classifier input is hidden_size — that only lines up for
        # width-768 backbones (e.g. gpt2). Tie both to hidden_size.
        self.projection = nn.Linear(len(models) * 3, hidden_size)
        self.classifier = nn.Linear(hidden_size, 3)

    def forward(self, input_ids, attention_mask):
        """input_ids/attention_mask are parallel lists, one tensor per sub-model."""
        hidden_states = []
        for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
            outputs = model(input_ids=input_id, attention_mask=attn_mask)
            hidden_states.append(outputs.logits)

        # (batch, len(models) * 3) after concatenating each model's 3 logits.
        concatenated_hidden_states = torch.cat(hidden_states, dim=1)
        projected_features = self.projection(concatenated_hidden_states)
        logits = self.classifier(projected_features)
        return logits

    @staticmethod
    def load_model_from_redis(redis_client):
        """Build a UnifiedModel from weights in Redis, falling back to base GPT-2.

        NOTE(review): the same model instance is wrapped twice, so both
        ensemble slots share parameters — confirm this is intentional.
        """
        model_name = "unified_model"
        model_data_bytes = redis_client.get(f"model:{model_name}")
        if model_data_bytes:
            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
            # BUG FIX: torch.load needs a file-like object; redis_client must
            # return raw bytes (decode_responses=False) for this to work.
            model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
        else:
            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
        return UnifiedModel([model, model])
101
+
102
class SyntheticDataset(Dataset):
    """Dataset over {'text': ..., 'label': ...} records, encoded per tokenizer.

    Every sample is tokenized once for each entry in *tokenizers*; tensors are
    keyed as input_ids_<name> / attention_mask_<name>, plus a 'labels' tensor.
    """

    def __init__(self, tokenizers, data):
        self.tokenizers = tokenizers
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        sample = {}
        for name, tok in self.tokenizers.items():
            encoded = tok(record['text'], padding="max_length", truncation=True, max_length=128)
            sample[f"input_ids_{name}"] = torch.tensor(encoded["input_ids"])
            sample[f"attention_mask_{name}"] = torch.tensor(encoded["attention_mask"])
            # Assigned per tokenizer (same value each time) to mirror the
            # original control flow exactly, including the no-tokenizers case.
            sample["labels"] = torch.tensor(record['label'])
        return sample
121
+
122
# In-memory per-user chat history: user_id -> list of raw message strings.
# NOTE(review): unbounded and process-local — lost on restart and not shared
# across worker processes.
conversation_history = {}
123
+
124
@app.post("/process")
async def process(request: Request):
    """Handle both training submissions ('train') and chat turns ('message').

    Returns an acknowledgement dict for training data, or {'answer': ...} for
    a chat message; raises HTTP 400 when the JSON body has neither key.
    """
    data = await request.json()
    # NOTE(review): decode_responses=True makes get() return str; torch.load
    # below expects a bytes/file-like state_dict payload — confirm the stored
    # model format actually round-trips through this client.
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)

    tokenizers = {}
    models = {}

    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")

    # Use stored fine-tuned weights when present; otherwise fall back to the
    # base GPT-2 classification head.
    if model_data_bytes:
        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
        model.load_state_dict(torch.load(model_data_bytes))
    else:
        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
    models[model_name] = model

    # Same pattern for the tokenizer: base GPT-2, optionally extended with
    # tokens stored in Redis (JSON payload).
    if tokenizer_data_bytes:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
    else:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizers[tokenizer_name] = tokenizer

    unified_model = UnifiedModel(list(models.values()))
    unified_model.to(torch.device("cpu"))

    if data.get("train"):
        # Training request: enqueue samples; a background worker consumes them.
        user_data = data.get("user_data", [])
        if not user_data:
            # Seed samples so the queue entry is never empty.
            user_data = [
                {"text": "Hola", "label": 1},
                {"text": "Necesito ayuda", "label": 2},
                {"text": "No entiendo", "label": 0}
            ]

        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": user_data
        }))

        return {"message": "Training data received. Model will be updated asynchronously."}

    elif data.get("message"):
        # Chat request: classify recent context, answer, and self-train on the
        # model's own prediction.
        user_id = data.get("user_id")
        text = data['message']
        language = data.get("language", default_language)

        if user_id not in conversation_history:
            conversation_history[user_id] = []
        conversation_history[user_id].append(text)

        # Only the last three turns are used as context.
        contextualized_text = " ".join(conversation_history[user_id][-3:])

        tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
        input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
        attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]

        with torch.no_grad():
            logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            predicted_class = torch.argmax(logits, dim=-1).item()

        response = chatbot_service.get_response(user_id, contextualized_text, language)

        # Feed the prediction back into the queue as a pseudo-label.
        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": [{"text": contextualized_text, "label": predicted_class}]
        }))

        return {"answer": response}

    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
201
+
202
def get_chatbot_response(user_id, question, predicted_class, language):
    """Record *question* in the user's history and delegate to the chat service.

    predicted_class is accepted for interface compatibility but not used here.
    """
    history = conversation_history.setdefault(user_id, [])
    history.append(question)
    return chatbot_service.get_response(user_id, question, language)
207
+
208
@app.get("/")
async def get_home():
    """Serve the single-page chat UI with a fresh per-visit user id."""
    # Each page load gets its own conversation id, embedded into the page's
    # script so /process can track history per visitor.
    user_id = str(uuid.uuid4())
    # Literal braces in the CSS/JS are doubled ({{ }}) because this is an
    # f-string; only {user_id} is interpolated.
    html_code = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Chatbot</title>
        <style>
            body {{
                font-family: 'Arial', sans-serif;
                background-color: #f4f4f9;
                margin: 0;
                padding: 0;
                display: flex;
                align-items: center;
                justify-content: center;
                min-height: 100vh;
            }}
            .container {{
                background-color: #fff;
                border-radius: 10px;
                box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
                overflow: hidden;
                width: 400px;
                max-width: 90%;
            }}
            h1 {{
                color: #333;
                text-align: center;
                padding: 20px;
                margin: 0;
                background-color: #f8f9fa;
                border-bottom: 1px solid #eee;
            }}
            #chatbox {{
                height: 300px;
                overflow-y: auto;
                padding: 10px;
                border-bottom: 1px solid #eee;
            }}
            .message {{
                margin-bottom: 10px;
            }}
            .user {{
                color: #007bff;
            }}
            .bot {{
                color: #28a745;
            }}
            #input {{
                display: flex;
                padding: 10px;
            }}
            #input textarea {{
                flex: 1;
                padding: 10px;
                border: 1px solid #ddd;
                border-radius: 4px;
                margin-right: 10px;
            }}
            #input button {{
                padding: 10px 20px;
                border: none;
                border-radius: 4px;
                background-color: #007bff;
                color: #fff;
                cursor: pointer;
            }}
            #input button:hover {{
                background-color: #0056b3;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>Chatbot</h1>
            <div id="chatbox"></div>
            <div id="input">
                <textarea id="message" rows="3" placeholder="Escribe tu mensaje aquí..."></textarea>
                <button id="send">Enviar</button>
            </div>
        </div>
        <script>
            const chatbox = document.getElementById('chatbox');
            const messageInput = document.getElementById('message');
            const sendButton = document.getElementById('send');

            function appendMessage(text, sender) {{
                const messageDiv = document.createElement('div');
                messageDiv.classList.add('message', sender);
                messageDiv.textContent = text;
                chatbox.appendChild(messageDiv);
                chatbox.scrollTop = chatbox.scrollHeight;
            }}

            async function sendMessage() {{
                const message = messageInput.value;
                if (!message.trim()) return;

                appendMessage(message, 'user');
                messageInput.value = '';

                const response = await fetch('/process', {{
                    method: 'POST',
                    headers: {{
                        'Content-Type': 'application/json'
                    }},
                    body: JSON.stringify({{
                        message: message,
                        user_id: '{user_id}'
                    }})
                }});
                const data = await response.json();
                appendMessage(data.answer, 'bot');
            }}

            sendButton.addEventListener('click', sendMessage);
            messageInput.addEventListener('keypress', (e) => {{
                if (e.key === 'Enter' && !e.shiftKey) {{
                    e.preventDefault();
                    sendMessage();
                }}
            }});
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code)
338
+
339
def train_unified_model():
    """Background worker: drain the Redis training queue, fine-tune, persist.

    Runs forever, polling every 60 seconds. Each queue item trains a fresh
    GPT-2 classifier from scratch.
    NOTE(review): previously stored weights are not loaded back before
    training, so each round discards earlier progress — confirm intended.
    """
    while True:
        # Text client for the JSON queue entries.
        redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
        # Binary client for the serialized state_dict (raw bytes must not be
        # UTF-8 decoded/encoded by the client).
        binary_redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=False)
        training_queue = redis_client.lrange("training_queue", 0, -1)
        if training_queue:
            for item in training_queue:
                item_data = json.loads(item)
                # NOTE(review): only the tokenizer *names* are used here; the
                # vocab payload stored with each queue entry is ignored —
                # confirm that is acceptable.
                tokenizers = {name: AutoTokenizer.from_pretrained("gpt2") for name in item_data["tokenizers"]}
                data = item_data["data"]
                dataset = SyntheticDataset(tokenizers, data)
                dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

                model = UnifiedModel([AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)])
                optimizer = AdamW(model.parameters(), lr=1e-5)
                criterion = nn.CrossEntropyLoss()

                for epoch in range(3):
                    model.train()
                    for batch in dataloader:
                        input_ids = [batch[f"input_ids_{name}"] for name in tokenizers]
                        attention_mask = [batch[f"attention_mask_{name}"] for name in tokenizers]
                        labels = batch["labels"]

                        optimizer.zero_grad()
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                        loss = criterion(outputs, labels)
                        loss.backward()
                        optimizer.step()

                # BUG FIX: torch.save(...) returns None, so the original stored
                # None in Redis (and left a stray file on disk). Serialize to
                # an in-memory buffer and store the raw bytes instead.
                buffer = io.BytesIO()
                torch.save(model.state_dict(), buffer)
                binary_redis_client.set("model:unified_model", buffer.getvalue())

            redis_client.delete("training_queue")
        time.sleep(60)
373
+
374
if __name__ == "__main__":
    # Run the trainer in a separate process so the API stays responsive while
    # models are fine-tuned in the background.
    training_process = multiprocessing.Process(target=train_unified_model)
    training_process.start()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)