Update main.py
main.py
CHANGED
@@ -39,16 +39,12 @@ class ChatbotService:
     def get_response(self, user_id, message, language=default_language):
         if self.model is None or self.tokenizer is None:
             return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
-
         input_text = f"Usuario: {message} Asistente:"
         input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cpu")
-
         with torch.no_grad():
             output = self.model.generate(input_ids=input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
-
         response = self.tokenizer.decode(output[0], skip_special_tokens=True)
         response = response.replace(input_text, "").strip()
-
         return response
 
     def load_model_from_redis(self):
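Note on the generation call kept above: max_length=100 counts the prompt tokens as well as the reply, so a long contextualized prompt leaves little room for new tokens. A minimal standalone sketch of the same beam-search call, with a hypothetical prompt (standard transformers API):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_text = "Usuario: Hola Asistente:"  # hypothetical prompt
input_ids = tokenizer.encode(input_text, return_tensors="pt")
with torch.no_grad():
    # max_length bounds prompt + completion together;
    # max_new_tokens would bound only the completion
    output = model.generate(input_ids=input_ids, max_length=100, num_beams=5,
                            no_repeat_ngram_size=2, early_stopping=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))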
@@ -64,6 +60,7 @@ class ChatbotService:
         if tokenizer_data_bytes:
             tokenizer = AutoTokenizer.from_pretrained("gpt2")
             tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
+            tokenizer.pad_token = tokenizer.eos_token
             return tokenizer
         return None
 
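The added pad_token assignment matters because GPT-2's tokenizer ships without a pad token: any call that pads a batch raises a ValueError until one is set. A minimal sketch reproducing the fix (reusing EOS as PAD, as this commit does):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 defines no pad token; padding a batch without one raises ValueError
tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(["Hola", "Necesito ayuda"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # both rows padded to the longer sequence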
@@ -82,7 +79,6 @@ class UnifiedModel(nn.Module):
         for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
             outputs = model(input_ids=input_id, attention_mask=attn_mask)
             hidden_states.append(outputs.logits)
-
         concatenated_hidden_states = torch.cat(hidden_states, dim=1)
         projected_features = self.projection(concatenated_hidden_states)
         logits = self.classifier(projected_features)
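For readers following the forward pass: torch.cat(..., dim=1) joins each sub-model's logits along the sequence axis, which only works while the batch and vocabulary dimensions agree across models. A shape sketch under assumed dimensions:

import torch

# assumed shapes: (batch, seq_len, vocab); seq_len may differ per tokenizer
h1 = torch.randn(1, 5, 50257)
h2 = torch.randn(1, 7, 50257)
combined = torch.cat([h1, h2], dim=1)  # every dim except dim=1 must match
print(combined.shape)  # torch.Size([1, 12, 50257])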
@@ -145,8 +141,10 @@ async def process(request: Request):
         if tokenizer_data_bytes:
             tokenizer = AutoTokenizer.from_pretrained("gpt2")
             tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
+            tokenizer.pad_token = tokenizer.eos_token
         else:
             tokenizer = AutoTokenizer.from_pretrained("gpt2")
+            tokenizer.pad_token = tokenizer.eos_token
         tokenizers[tokenizer_name] = tokenizer
 
     unified_model = UnifiedModel(list(models.values()))
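Since both branches now build the same base tokenizer and set the same pad token, the shared lines could be hoisted out of the if/else. A hypothetical helper sketch, assuming the surrounding tokenizer_data_bytes value:

import json

from transformers import AutoTokenizer

def build_tokenizer(tokenizer_data_bytes):
    # hypothetical refactor: the branches above differ only in add_tokens
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer_data_bytes:
        tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer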
@@ -160,42 +158,31 @@ async def process(request: Request):
             {"text": "Necesito ayuda", "label": 2},
             {"text": "No entiendo", "label": 0}
         ]
-
         redis_client.rpush("training_queue", json.dumps({
             "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
             "data": user_data
         }))
-
         return {"message": "Training data received. Model will be updated asynchronously."}
-
     elif data.get("message"):
         user_id = data.get("user_id")
         text = data['message']
         language = data.get("language", default_language)
-
         if user_id not in conversation_history:
             conversation_history[user_id] = []
         conversation_history[user_id].append(text)
-
         contextualized_text = " ".join(conversation_history[user_id][-3:])
-
         tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
         input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
         attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
-
         with torch.no_grad():
             logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
             predicted_class = torch.argmax(logits, dim=-1).item()
-
         response = chatbot_service.get_response(user_id, contextualized_text, language)
-
         redis_client.rpush("training_queue", json.dumps({
             "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
             "data": [{"text": contextualized_text, "label": predicted_class}]
         }))
-
         return {"answer": response}
-
     else:
         raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
 
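Both branches hand work to the trainer by pushing a JSON payload onto the Redis list training_queue. A minimal round-trip sketch with hypothetical connection parameters (redis-py API):

import json

import redis

client = redis.StrictRedis(host="localhost", port=6379, decode_responses=True)  # hypothetical host/port
client.rpush("training_queue", json.dumps({
    "tokenizers": {"gpt2": {}},  # vocab omitted for brevity
    "data": [{"text": "Necesito ayuda", "label": 2}],
}))
item = json.loads(client.lrange("training_queue", 0, -1)[0])
print(item["data"])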
@@ -249,12 +236,15 @@ async def get_home():
         }}
         .message {{
             margin-bottom: 10px;
+            padding: 10px;
+            border-radius: 5px;
         }}
-        .user {{
-            color: #
+        .message.user {{
+            background-color: #e1f5fe;
+            text-align: right;
         }}
-        .bot {{
-            color: #
+        .message.bot {{
+            background-color: #f1f1f1;
         }}
         #input {{
             display: flex;
@@ -337,13 +327,15 @@ async def get_home():
     return HTMLResponse(content=html_code)
 
 def train_unified_model():
+    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
     while True:
-        redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
         training_queue = redis_client.lrange("training_queue", 0, -1)
         if training_queue:
             for item in training_queue:
                 item_data = json.loads(item)
                 tokenizers = {name: AutoTokenizer.from_pretrained("gpt2") for name in item_data["tokenizers"]}
+                for tokenizer in tokenizers.values():
+                    tokenizer.pad_token = tokenizer.eos_token
                 data = item_data["data"]
                 dataset = SyntheticDataset(tokenizers, data)
                 dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
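Moving the StrictRedis construction out of the while loop avoids reconnecting on every iteration. Note that lrange reads items without removing them; for contrast, a destructive-consumption sketch using blpop (an alternative pattern, not what this commit does), with hypothetical connection parameters:

import json

import redis

client = redis.StrictRedis(host="localhost", port=6379, decode_responses=True)  # hypothetical host/port
while True:
    # blpop blocks until an item arrives and removes it atomically,
    # so each payload is trained on exactly once and the loop never busy-waits
    _key, raw = client.blpop("training_queue")
    item_data = json.loads(raw)
    print(item_data["data"])  # train here, as in train_unified_model()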
@@ -375,4 +367,4 @@ if __name__ == "__main__":
     training_process = multiprocessing.Process(target=train_unified_model)
     training_process.start()
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=
+    uvicorn.run(app, host="0.0.0.0", port=8000)
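The new side pins the server to port 8000. A self-contained variant that reads the port from the environment instead, shown as a hypothetical alternative:

import os

import uvicorn
from fastapi import FastAPI

app = FastAPI()  # stand-in for the application defined in main.py

if __name__ == "__main__":
    # hypothetical: default to 8000 but let a PORT variable override it
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))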