Update main.py
main.py
CHANGED
@@ -29,6 +29,7 @@ REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
 
 app = FastAPI()
 
+# Language configuration
 language_responses = {
     "es": {
         0: [
@@ -46,6 +47,7 @@ language_responses = {
             "¿En qué puedo ayudarte?",
             "Dime, ¿qué necesitas?"
         ],
+        # ... more responses for other classes
     },
     "en": {
         0: [
@@ -63,30 +65,39 @@ language_responses = {
             "What can I do for you?",
             "Tell me, what do you need?"
         ],
+        # ... more responses for other classes
     }
 }
 
-default_language = "es"
+default_language = "es"  # Default language
 
+# Chatbot service
 class ChatbotService:
     def get_response(self, user_id, message, predicted_class, language=default_language):
+        # Access the response dictionary for the given language
         responses = language_responses.get(language, language_responses["es"])
+
+        # Logic based on the predicted class
         if predicted_class == 1:
+            # Pick a random greeting response
            return random.choice(responses[1])
        elif predicted_class == 2:
+            # Pick a random help response
            return random.choice(responses[2])
        else:
+            # Pick a random "not understood" response
            return random.choice(responses[0])
 
 chatbot_service = ChatbotService()
 
+# Text classification model
 class UnifiedModel(nn.Module):
     def __init__(self, models):
         super(UnifiedModel, self).__init__()
         self.models = nn.ModuleList(models)
         hidden_size = self.models[0].config.hidden_size
-        self.projection = nn.Linear(len(models) *
-        self.classifier = nn.Linear(hidden_size, 3)
+        self.projection = nn.Linear(len(models) * 3, 768)  # Corrected output dimension
+        self.classifier = nn.Linear(hidden_size, 3)  # 3 classes
 
     def forward(self, input_ids, attention_mask):
         hidden_states = []
@@ -95,10 +106,10 @@ class UnifiedModel(nn.Module):
                 input_ids=input_id,
                 attention_mask=attn_mask
             )
-            hidden_states.append(outputs.logits)
+            hidden_states.append(outputs.logits)  # Use outputs.logits directly
 
         concatenated_hidden_states = torch.cat(hidden_states, dim=1)
-        projected_features = self.projection(concatenated_hidden_states)
+        projected_features = self.projection(concatenated_hidden_states)  # Project to hidden_size
         logits = self.classifier(projected_features)
         return logits
 
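A quick way to sanity-check the corrected projection wiring: with two GPT-2 heads of `num_labels=3`, the concatenated logits have width `len(models) * 3 = 6`, which is exactly what `nn.Linear(6, 768)` expects before the 3-class classifier. A minimal standalone sketch, with plain `nn.Linear` layers standing in for the transformer models and shapes taken from the diff:

```python
import torch
import torch.nn as nn

# Stand-ins for the two 3-label model heads: each maps a pooled
# feature vector to 3 logits, mirroring outputs.logits in the diff.
heads = [nn.Linear(768, 3) for _ in range(2)]

features = torch.randn(4, 768)                    # batch of 4 pooled vectors
logits_per_head = [h(features) for h in heads]    # two (4, 3) tensors
concatenated = torch.cat(logits_per_head, dim=1)  # (4, 6) == len(models) * 3

projection = nn.Linear(len(heads) * 3, 768)       # matches the corrected layer
classifier = nn.Linear(768, 3)                    # 3 classes

out = classifier(projection(concatenated))
print(out.shape)  # torch.Size([4, 3])
```

Note that this design projects classifier logits rather than hidden states, so the 768-wide features are derived from only 6 numbers per example.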
@@ -107,13 +118,14 @@ class UnifiedModel(nn.Module):
     model_name = "unified_model"
     model_data_bytes = redis_client.get(f"model:{model_name}")
     if model_data_bytes:
-        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)  # 3 classes
         model.load_state_dict(torch.load(model_data_bytes))
     else:
-        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)  # 3 classes
 
-    return UnifiedModel([model, model])
+    return UnifiedModel([model, model])  # Make sure a list of models is used, two in this case
 
+# Training dataset
 class SyntheticDataset(Dataset):
     def __init__(self, tokenizers, data):
         self.tokenizers = tokenizers
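One caveat on the load path (unchanged by this diff): `redis_client.get` returns raw bytes, while `torch.load` expects a path or a file-like object, so the bytes typically need wrapping in `io.BytesIO` first. A hedged sketch of the usual pattern, with `load_state_from_redis` as a hypothetical helper name:

```python
import io
import torch

def load_state_from_redis(redis_client, key):
    """Fetch a serialized state_dict from Redis and deserialize it.

    torch.load needs a file-like object, so the raw bytes returned
    by redis_client.get are wrapped in io.BytesIO before loading.
    """
    data = redis_client.get(key)
    if data is None:
        return None
    return torch.load(io.BytesIO(data))
```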
@@ -134,6 +146,7 @@ class SyntheticDataset(Dataset):
         tokenized["labels"] = torch.tensor(label)
         return tokenized
 
+# Conversation handling
 conversation_history = {}
 
 @app.post("/process")
@@ -151,10 +164,10 @@ async def process(request: Request):
         tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
 
         if model_data_bytes:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)  # 3 classes
             model.load_state_dict(torch.load(model_data_bytes))
         else:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)  # 3 classes
         models[model_name] = model
 
         if tokenizer_data_bytes:
@@ -174,6 +187,7 @@ async def process(request: Request):
             {"text": "Hola", "label": 1},
             {"text": "Necesito ayuda", "label": 2},
             {"text": "No entiendo", "label": 0}
+            # ... more examples for other classes
         ]
 
         redis_client.rpush("training_queue", json.dumps({
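On the training branch, jobs are serialized with `json.dumps` and pushed onto the `training_queue` list, so a separate worker would pop and decode them. A minimal consumer sketch; the queue name comes from the diff, while the job structure and the worker loop are assumptions:

```python
import json

def drain_training_queue(redis_client):
    # BLPOP blocks until an item arrives on the queue the diff pushes to
    while True:
        _, raw = redis_client.blpop("training_queue")
        job = json.loads(raw)
        print("training job:", job)  # actual job keys depend on the payload above
```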
@@ -188,11 +202,13 @@ async def process(request: Request):
         text = data['message']
         language = data.get("language", default_language)
 
+        # Conversation memory
         if user_id not in conversation_history:
             conversation_history[user_id] = []
         conversation_history[user_id].append(text)
 
-
+        # Concatenate the history onto the current message (other techniques could be used)
+        contextualized_text = " ".join(conversation_history[user_id][-3:])  # Use the last 3 messages
 
         tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
         input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
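The conversation memory added here is just a whitespace join over the last three stored turns, for example:

```python
conversation_history = {"u1": ["Hola", "Necesito ayuda", "No entiendo", "Hola otra vez"]}
contextualized_text = " ".join(conversation_history["u1"][-3:])
print(contextualized_text)  # Necesito ayuda No entiendo Hola otra vez
```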
@@ -209,6 +225,7 @@ async def process(request: Request):
     raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
 
 def get_chatbot_response(user_id, question, predicted_class, language):
+    # Store the message in the history
    if user_id not in conversation_history:
        conversation_history[user_id] = []
    conversation_history[user_id].append(question)
|