Update main.py
Browse files
main.py
CHANGED
@@ -6,6 +6,7 @@ import redis
|
|
6 |
from transformers import (
|
7 |
AutoTokenizer,
|
8 |
AutoModelForSequenceClassification,
|
|
|
9 |
)
|
10 |
import torch
|
11 |
import torch.nn as nn
|
@@ -29,8 +30,46 @@ app = FastAPI()
|
|
29 |
default_language = "es"
|
30 |
|
31 |
class ChatbotService:
|
32 |
-
def
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
chatbot_service = ChatbotService()
|
36 |
|
@@ -155,7 +194,13 @@ async def process(request: Request):
|
|
155 |
logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
|
156 |
predicted_class = torch.argmax(logits, dim=-1).item()
|
157 |
|
158 |
-
response =
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
return {"answer": response}
|
160 |
|
161 |
else:
|
@@ -166,7 +211,7 @@ def get_chatbot_response(user_id, question, predicted_class, language):
|
|
166 |
conversation_history[user_id] = []
|
167 |
conversation_history[user_id].append(question)
|
168 |
|
169 |
-
return chatbot_service.get_response(user_id, question,
|
170 |
|
171 |
@app.get("/")
|
172 |
async def get_home():
|
@@ -364,7 +409,13 @@ def continuous_training():
|
|
364 |
|
365 |
print(f"Epoch {epoch}, Loss {loss.item()}")
|
366 |
|
367 |
-
push_to_redis(
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
time.sleep(10)
|
369 |
except Exception as e:
|
370 |
print(f"Error in continuous training: {e}")
|
|
|
6 |
from transformers import (
|
7 |
AutoTokenizer,
|
8 |
AutoModelForSequenceClassification,
|
9 |
+
AutoModelForCausalLM,
|
10 |
)
|
11 |
import torch
|
12 |
import torch.nn as nn
|
|
|
30 |
# Fallback language code used when a caller does not supply one (Spanish).
default_language = "es"
|
31 |
|
32 |
class ChatbotService:
    """Serve chat responses from a causal LM whose weights and extra tokenizer
    tokens are published to Redis (presumably by the continuous-training loop —
    see `push_to_redis` elsewhere in this file; TODO confirm the exact format).
    """

    def __init__(self):
        # decode_responses MUST be False here: the stored model state_dict is
        # binary pickle data, and decode_responses=True would UTF-8-decode
        # (and corrupt) it before torch.load ever sees it.
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST,
            port=REDIS_PORT,
            password=REDIS_PASSWORD,
            decode_responses=False,
        )
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"

    def get_response(self, user_id, message, language=default_language):
        """Generate an assistant reply for *message*.

        Returns a Spanish "model not ready" notice until both the model and
        the tokenizer have been published to Redis.

        NOTE(review): `user_id` and `language` are accepted but never used in
        generation — kept for caller compatibility; confirm whether they were
        meant to influence the prompt.
        """
        model = self.load_model_from_redis()
        tokenizer = self.load_tokenizer_from_redis()

        if model is None or tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."

        input_text = f"Usuario: {message} Asistente:"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cpu")

        # Inference only: no gradients needed.
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                max_length=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )

        response = tokenizer.decode(output[0], skip_special_tokens=True)
        # generate() echoes the prompt; strip it so only the continuation
        # (the assistant's reply) is returned.
        response = response.replace(input_text, "").strip()

        return response

    def load_model_from_redis(self):
        """Load the latest published model weights, or None if not yet pushed.

        NOTE(security): torch.load unpickles arbitrary data. This is only
        safe if the Redis instance is private/trusted; do not point it at an
        untrusted server.
        """
        import io  # local import: only needed to wrap raw bytes for torch.load

        model_data_bytes = self.redis_client.get(f"model:{self.model_name}")
        if not model_data_bytes:
            return None
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        # torch.load requires a file path or file-like object, not raw bytes —
        # wrap the Redis blob in an in-memory buffer.
        model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
        return model

    def load_tokenizer_from_redis(self):
        """Rebuild the tokenizer plus any published extra tokens, or return
        None if the tokenizer data has not been pushed to Redis yet."""
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if not tokenizer_data_bytes:
            return None
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # json.loads accepts bytes directly (Python >= 3.6).
        # NOTE(review): elsewhere the queue stores tokenizer.get_vocab() (a
        # token->id dict); add_tokens expects a list of token strings —
        # confirm what push_to_redis actually serializes under this key.
        tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
        return tokenizer
|
73 |
|
74 |
# Module-level singleton; shared by the request handlers defined below.
chatbot_service = ChatbotService()
|
75 |
|
|
|
194 |
logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
|
195 |
predicted_class = torch.argmax(logits, dim=-1).item()
|
196 |
|
197 |
+
response = chatbot_service.get_response(user_id, contextualized_text, language)
|
198 |
+
|
199 |
+
redis_client.rpush("training_queue", json.dumps({
|
200 |
+
"tokenizers": {tokenizer_name: tokenizer.get_vocab()},
|
201 |
+
"data": [{"text": contextualized_text, "label": predicted_class}]
|
202 |
+
}))
|
203 |
+
|
204 |
return {"answer": response}
|
205 |
|
206 |
else:
|
|
|
211 |
conversation_history[user_id] = []
|
212 |
conversation_history[user_id].append(question)
|
213 |
|
214 |
+
return chatbot_service.get_response(user_id, question, language)
|
215 |
|
216 |
@app.get("/")
|
217 |
async def get_home():
|
|
|
409 |
|
410 |
print(f"Epoch {epoch}, Loss {loss.item()}")
|
411 |
|
412 |
+
push_to_redis(
|
413 |
+
{"response_model": unified_model},
|
414 |
+
{"response_tokenizer": tokenizer},
|
415 |
+
redis_client,
|
416 |
+
"response_model",
|
417 |
+
"response_tokenizer",
|
418 |
+
)
|
419 |
time.sleep(10)
|
420 |
except Exception as e:
|
421 |
print(f"Error in continuous training: {e}")
|