Update main.py
main.py
CHANGED
@@ -6,15 +6,12 @@ import redis
 from transformers import (
     AutoTokenizer,
     AutoModelForSequenceClassification,
-    TrainingArguments,
 )
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader, Dataset
 from torch.optim import AdamW
 from fastapi import FastAPI, HTTPException, Request
-from pydantic import BaseModel
-from typing import List, Dict
 from fastapi.responses import HTMLResponse
 import multiprocessing
 import time
@@ -29,75 +26,21 @@ REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
 
 app = FastAPI()
 
-
-language_responses = {
-    "es": {
-        0: [
-            "Lo siento, no entiendo.",
-            "No estoy seguro de entender lo que quieres decir.",
-            "¿Podrías reformular tu pregunta?"
-        ],
-        1: [
-            "Hola! ¿Cómo estás?",
-            "¡Hola! ¿Qué tal?",
-            "Buenos días/tardes/noches, ¿cómo te va?"
-        ],
-        2: [
-            "¿Cómo te puedo ayudar?",
-            "¿En qué puedo ayudarte?",
-            "Dime, ¿qué necesitas?"
-        ],
-        # ... more responses for other classes
-    },
-    "en": {
-        0: [
-            "Sorry, I don't understand.",
-            "I'm not sure I understand what you mean.",
-            "Could you rephrase your question?"
-        ],
-        1: [
-            "Hello! How are you?",
-            "Hi! What's up?",
-            "Good morning/afternoon/evening, how are you doing?"
-        ],
-        2: [
-            "How can I help you?",
-            "What can I do for you?",
-            "Tell me, what do you need?"
-        ],
-        # ... more responses for other classes
-    }
-}
-
-default_language = "es"  # Default language
-
-# Chatbot service
+default_language = "es"
+
 class ChatbotService:
     def get_response(self, user_id, message, predicted_class, language=default_language):
-
-        responses = language_responses.get(language, language_responses["es"])
-
-        # Logic based on the predicted class
-        if predicted_class == 1:
-            # Pick a random greeting response
-            return random.choice(responses[1])
-        elif predicted_class == 2:
-            # Pick a random help response
-            return random.choice(responses[2])
-        else:
-            # Pick a random fallback response
-            return random.choice(responses[0])
+        return "Respuesta por defecto."
 
 chatbot_service = ChatbotService()
 
-# Text classification model
 class UnifiedModel(nn.Module):
     def __init__(self, models):
         super(UnifiedModel, self).__init__()
         self.models = nn.ModuleList(models)
         hidden_size = self.models[0].config.hidden_size
-        self.projection = nn.Linear(len(models) * 3, 768)
-        self.classifier = nn.Linear(hidden_size, 3)
+        self.projection = nn.Linear(len(models) * 3, 768)
+        self.classifier = nn.Linear(hidden_size, 3)
 
     def forward(self, input_ids, attention_mask):
         hidden_states = []
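Reviewer note: this hunk drops the per-language language_responses tables and the predicted-class dispatch, leaving a stub that always returns "Respuesta por defecto." ("default response"), so user_id, message, predicted_class, and language are now unused in get_response. If the per-class behavior is restored later, a minimal sketch of the old shape (response strings abridged from the deleted lines):

import random

# Hypothetical restoration sketch: per-language, per-class response tables as
# in the deleted code, with class 0 serving as the fallback bucket.
RESPONSES = {
    "es": {0: ["Lo siento, no entiendo."], 1: ["Hola! ¿Cómo estás?"], 2: ["¿En qué puedo ayudarte?"]},
    "en": {0: ["Sorry, I don't understand."], 1: ["Hello! How are you?"], 2: ["How can I help you?"]},
}

def get_response(predicted_class, language="es"):
    table = RESPONSES.get(language, RESPONSES["es"])
    return random.choice(table.get(predicted_class, table[0]))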
@@ -106,11 +49,11 @@ class UnifiedModel(nn.Module):
                 input_ids=input_id,
                 attention_mask=attn_mask
             )
-            hidden_states.append(outputs.logits)
+            hidden_states.append(outputs.logits)
 
         concatenated_hidden_states = torch.cat(hidden_states, dim=1)
-        projected_features = self.projection(concatenated_hidden_states)
-        logits = self.classifier(projected_features)
+        projected_features = self.projection(concatenated_hidden_states)
+        logits = self.classifier(projected_features)
         return logits
 
     @staticmethod
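Reviewer note: the dimension chain in UnifiedModel is worth spelling out. Each sub-model is a 3-label sequence classifier, so its outputs.logits has shape (batch, 3); concatenating len(models) of them gives (batch, len(models) * 3), which self.projection maps to 768 features. The classifier is sized from hidden_size, so the pair only lines up because GPT-2's hidden size happens to be 768; a different backbone would break the chain. A minimal shape check, assuming the two-model case used in this diff:

import torch
import torch.nn as nn

batch = 2
# Stand-ins for the per-model logits collected in forward() (two 3-label heads).
logits_per_model = [torch.randn(batch, 3), torch.randn(batch, 3)]
concatenated = torch.cat(logits_per_model, dim=1)   # shape (2, 6)

projection = nn.Linear(2 * 3, 768)   # len(models) * 3 -> 768
classifier = nn.Linear(768, 3)       # hidden_size -> 3; valid only if hidden_size == 768

print(classifier(projection(concatenated)).shape)   # torch.Size([2, 3])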
@@ -118,14 +61,12 @@ class UnifiedModel(nn.Module):
         model_name = "unified_model"
         model_data_bytes = redis_client.get(f"model:{model_name}")
         if model_data_bytes:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
             model.load_state_dict(torch.load(model_data_bytes))
         else:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
-
-        return UnifiedModel([model, model])  # Make sure a list of models is used, 2 in this case
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+        return UnifiedModel([model, model])
 
-# Training dataset
 class SyntheticDataset(Dataset):
     def __init__(self, tokenizers, data):
         self.tokenizers = tokenizers
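Reviewer note: torch.load(model_data_bytes) is handed raw bytes, but torch.load expects a path or a file-like object, so this line would raise as written; the same pattern recurs in the per-model loop inside /process below. There is a second snag: the redis client used elsewhere in this file is created with decode_responses=True, which returns str rather than bytes, so a binary-safe client is needed for the model blob. A hedged sketch of the usual fix (load_model is a hypothetical helper, not in the diff):

import io

import torch
from transformers import AutoModelForSequenceClassification

def load_model(redis_client, model_name="unified_model"):
    # decode_responses must be False on this client so the blob comes back as bytes.
    blob = redis_client.get(f"model:{model_name}")
    model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
    if blob:
        # torch.load needs a file-like object, not raw bytes.
        model.load_state_dict(torch.load(io.BytesIO(blob)))
    return model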
@@ -146,14 +87,13 @@ class SyntheticDataset(Dataset):
         tokenized["labels"] = torch.tensor(label)
         return tokenized
 
-# Conversation handling
 conversation_history = {}
 
 @app.post("/process")
 async def process(request: Request):
     data = await request.json()
     redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
-
+
     tokenizers = {}
     models = {}
 
@@ -164,10 +104,10 @@ async def process(request: Request):
         tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
 
         if model_data_bytes:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
             model.load_state_dict(torch.load(model_data_bytes))
         else:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
         models[model_name] = model
 
         if tokenizer_data_bytes:
@@ -179,53 +119,49 @@ async def process(request: Request):
 
     unified_model = UnifiedModel(list(models.values()))
     unified_model.to(torch.device("cpu"))
-
+
     if data.get("train"):
         user_data = data.get("user_data", [])
        if not user_data:
             user_data = [
-                {"text": "Hola", "label": 1},
+                {"text": "Hola", "label": 1},
                 {"text": "Necesito ayuda", "label": 2},
                 {"text": "No entiendo", "label": 0}
-                # ... more examples for other classes
             ]
-
+
         redis_client.rpush("training_queue", json.dumps({
             "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
             "data": user_data
         }))
 
         return {"message": "Training data received. Model will be updated asynchronously."}
-
+
     elif data.get("message"):
         user_id = data.get("user_id")
         text = data['message']
         language = data.get("language", default_language)
 
-        # Conversation memory
         if user_id not in conversation_history:
             conversation_history[user_id] = []
         conversation_history[user_id].append(text)
 
-
-        contextualized_text = " ".join(conversation_history[user_id][-3:])  # Use the last 3 messages
+        contextualized_text = " ".join(conversation_history[user_id][-3:])
 
         tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
         input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
         attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
-
+
         with torch.no_grad():
             logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
             predicted_class = torch.argmax(logits, dim=-1).item()
-
+
         response = get_chatbot_response(user_id, text, predicted_class, language)
         return {"answer": response}
-
+
     else:
         raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
 
 def get_chatbot_response(user_id, question, predicted_class, language):
-    # Store the message in the history
     if user_id not in conversation_history:
         conversation_history[user_id] = []
     conversation_history[user_id].append(question)
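Reviewer note: the handler dispatches on which key the JSON body carries — "train" queues data for the background trainer, "message" runs inference. Hypothetical client calls against the endpoint (base URL and port are assumptions, not part of the diff):

import requests

BASE = "http://localhost:8000"  # assumed host/port

# Queue training examples; the handler pushes them onto the Redis training_queue.
requests.post(f"{BASE}/process", json={
    "train": True,
    "user_data": [{"text": "Hola", "label": 1}],
})

# Ask for a reply; "language" falls back to default_language ("es") when omitted.
r = requests.post(f"{BASE}/process", json={
    "user_id": "u1",
    "message": "Necesito ayuda",
})
print(r.json())  # expected shape: {"answer": "..."}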
@@ -401,7 +337,7 @@ def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
 
 def continuous_training():
     redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
-
+
     while True:
         try:
             data = redis_client.lpop("training_queue")
@@ -409,12 +345,12 @@ def continuous_training():
             data = json.loads(data)
             unified_model = UnifiedModel.load_model_from_redis(redis_client)
             unified_model.train()
-
+
             train_dataset = SyntheticDataset(data["tokenizers"], data["data"])
             train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
-
+
             optimizer = AdamW(unified_model.parameters(), lr=5e-5)
-
+
             for epoch in range(10):
                 for batch in train_loader:
                     input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in data["tokenizers"].keys()]
@@ -425,9 +361,9 @@ def continuous_training():
                     loss.backward()
                     optimizer.step()
                     optimizer.zero_grad()
-
+
                 print(f"Epoch {epoch}, Loss {loss.item()}")
-
+
             push_to_redis(unified_model.models, data["tokenizers"], redis_client, "unified_model", "unified_tokenizer")
             time.sleep(10)
         except Exception as e:
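Reviewer note: the diff doesn't show how continuous_training is started. Given the multiprocessing import at the top of main.py, one plausible wiring (an assumption, not part of this commit) is a daemon worker launched alongside the API:

import multiprocessing

if __name__ == "__main__":
    # Run the training loop out-of-process so it doesn't block the API worker.
    trainer = multiprocessing.Process(target=continuous_training, daemon=True)
    trainer.start()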