Yjhhh committed on
Commit
9df8ce5
·
verified ·
1 Parent(s): a6520e3

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -11
main.py CHANGED
@@ -29,6 +29,7 @@ REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
29
 
30
  app = FastAPI()
31
 
 
32
  language_responses = {
33
  "es": {
34
  0: [
@@ -46,6 +47,7 @@ language_responses = {
46
  "¿En qué puedo ayudarte?",
47
  "Dime, ¿qué necesitas?"
48
  ],
 
49
  },
50
  "en": {
51
  0: [
@@ -63,30 +65,39 @@ language_responses = {
63
  "What can I do for you?",
64
  "Tell me, what do you need?"
65
  ],
 
66
  }
67
  }
68
 
69
- default_language = "es"
70
 
 
71
class ChatbotService:
    """Serve a canned chat reply for a predicted intent class and language."""

    def get_response(self, user_id, message, predicted_class, language=default_language):
        """Return a random reply for ``predicted_class`` in ``language``.

        Unknown languages fall back to the Spanish catalog; any class other
        than 1 or 2 is answered from bucket 0 (not-understood).
        """
        catalog = language_responses.get(language, language_responses["es"])
        bucket = predicted_class if predicted_class in (1, 2) else 0
        return random.choice(catalog[bucket])
80
 
81
  chatbot_service = ChatbotService()
82
 
 
83
  class UnifiedModel(nn.Module):
84
  def __init__(self, models):
85
  super(UnifiedModel, self).__init__()
86
  self.models = nn.ModuleList(models)
87
  hidden_size = self.models[0].config.hidden_size
88
- self.projection = nn.Linear(len(models) * hidden_size, hidden_size)
89
- self.classifier = nn.Linear(hidden_size, 3)
90
 
91
  def forward(self, input_ids, attention_mask):
92
  hidden_states = []
@@ -95,10 +106,10 @@ class UnifiedModel(nn.Module):
95
  input_ids=input_id,
96
  attention_mask=attn_mask
97
  )
98
- hidden_states.append(outputs.logits)
99
 
100
  concatenated_hidden_states = torch.cat(hidden_states, dim=1)
101
- projected_features = self.projection(concatenated_hidden_states)
102
  logits = self.classifier(projected_features)
103
  return logits
104
 
@@ -107,13 +118,14 @@ class UnifiedModel(nn.Module):
107
  model_name = "unified_model"
108
  model_data_bytes = redis_client.get(f"model:{model_name}")
109
  if model_data_bytes:
110
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
111
  model.load_state_dict(torch.load(model_data_bytes))
112
  else:
113
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
114
 
115
- return UnifiedModel([model, model])
116
 
 
117
  class SyntheticDataset(Dataset):
118
  def __init__(self, tokenizers, data):
119
  self.tokenizers = tokenizers
@@ -134,6 +146,7 @@ class SyntheticDataset(Dataset):
134
  tokenized["labels"] = torch.tensor(label)
135
  return tokenized
136
 
 
137
  conversation_history = {}
138
 
139
  @app.post("/process")
@@ -151,10 +164,10 @@ async def process(request: Request):
151
  tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
152
 
153
  if model_data_bytes:
154
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
155
  model.load_state_dict(torch.load(model_data_bytes))
156
  else:
157
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
158
  models[model_name] = model
159
 
160
  if tokenizer_data_bytes:
@@ -174,6 +187,7 @@ async def process(request: Request):
174
  {"text": "Hola", "label": 1},
175
  {"text": "Necesito ayuda", "label": 2},
176
  {"text": "No entiendo", "label": 0}
 
177
  ]
178
 
179
  redis_client.rpush("training_queue", json.dumps({
@@ -188,11 +202,13 @@ async def process(request: Request):
188
  text = data['message']
189
  language = data.get("language", default_language)
190
 
 
191
  if user_id not in conversation_history:
192
  conversation_history[user_id] = []
193
  conversation_history[user_id].append(text)
194
 
195
- contextualized_text = " ".join(conversation_history[user_id][-3:])
 
196
 
197
  tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
198
  input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
@@ -209,6 +225,7 @@ async def process(request: Request):
209
  raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
210
 
211
  def get_chatbot_response(user_id, question, predicted_class, language):
 
212
  if user_id not in conversation_history:
213
  conversation_history[user_id] = []
214
  conversation_history[user_id].append(question)
 
29
 
30
  app = FastAPI()
31
 
32
+ # Configuración de idioma
33
  language_responses = {
34
  "es": {
35
  0: [
 
47
  "¿En qué puedo ayudarte?",
48
  "Dime, ¿qué necesitas?"
49
  ],
50
+ # ... más respuestas para otras clases
51
  },
52
  "en": {
53
  0: [
 
65
  "What can I do for you?",
66
  "Tell me, what do you need?"
67
  ],
68
+ # ... más respuestas para otras clases
69
  }
70
  }
71
 
72
+ default_language = "es" # Idioma predeterminado
73
 
74
+ # Servicio de Chatbot
75
class ChatbotService:
    """Chatbot service: maps a predicted intent class to a canned response."""

    def get_response(self, user_id, message, predicted_class, language=default_language):
        """Pick a random reply for the predicted class, honoring ``language``.

        Falls back to the Spanish catalog when the language is unknown.
        """
        # Response catalog for the requested language (default: Spanish).
        responses = language_responses.get(language, language_responses["es"])
        if predicted_class == 1:
            # Random greeting reply.
            return random.choice(responses[1])
        if predicted_class == 2:
            # Random help reply.
            return random.choice(responses[2])
        # Anything else: not-understood reply.
        return random.choice(responses[0])
90
 
91
  chatbot_service = ChatbotService()
92
 
93
+ # Modelo de clasificación de texto
94
  class UnifiedModel(nn.Module):
95
  def __init__(self, models):
96
  super(UnifiedModel, self).__init__()
97
  self.models = nn.ModuleList(models)
98
  hidden_size = self.models[0].config.hidden_size
99
+ self.projection = nn.Linear(len(models) * 3, 768) # Dimensión de salida corregida
100
+ self.classifier = nn.Linear(hidden_size, 3) # 3 clases
101
 
102
  def forward(self, input_ids, attention_mask):
103
  hidden_states = []
 
106
  input_ids=input_id,
107
  attention_mask=attn_mask
108
  )
109
+ hidden_states.append(outputs.logits) # Usar directamente outputs.logits
110
 
111
  concatenated_hidden_states = torch.cat(hidden_states, dim=1)
112
+ projected_features = self.projection(concatenated_hidden_states) # Proyectar a hidden_size
113
  logits = self.classifier(projected_features)
114
  return logits
115
 
 
118
  model_name = "unified_model"
119
  model_data_bytes = redis_client.get(f"model:{model_name}")
120
  if model_data_bytes:
121
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
122
  model.load_state_dict(torch.load(model_data_bytes))
123
  else:
124
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
125
 
126
+ return UnifiedModel([model, model]) # Asegurar que se usa una lista de modelos, en este caso 2
127
 
128
+ # Dataset para entrenamiento
129
  class SyntheticDataset(Dataset):
130
  def __init__(self, tokenizers, data):
131
  self.tokenizers = tokenizers
 
146
  tokenized["labels"] = torch.tensor(label)
147
  return tokenized
148
 
149
+ # Manejo de la conversación
150
  conversation_history = {}
151
 
152
  @app.post("/process")
 
164
  tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
165
 
166
  if model_data_bytes:
167
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
168
  model.load_state_dict(torch.load(model_data_bytes))
169
  else:
170
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
171
  models[model_name] = model
172
 
173
  if tokenizer_data_bytes:
 
187
  {"text": "Hola", "label": 1},
188
  {"text": "Necesito ayuda", "label": 2},
189
  {"text": "No entiendo", "label": 0}
190
+ # ... más ejemplos para otras clases
191
  ]
192
 
193
  redis_client.rpush("training_queue", json.dumps({
 
202
  text = data['message']
203
  language = data.get("language", default_language)
204
 
205
+ # Memoria de Conversación
206
  if user_id not in conversation_history:
207
  conversation_history[user_id] = []
208
  conversation_history[user_id].append(text)
209
 
210
+ # Concatenar el historial al mensaje actual (puedes usar otra técnica)
211
+ contextualized_text = " ".join(conversation_history[user_id][-3:]) # Usar los últimos 3 mensajes
212
 
213
  tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
214
  input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
 
225
  raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
226
 
227
  def get_chatbot_response(user_id, question, predicted_class, language):
228
+ # Almacenar el mensaje en el historial
229
  if user_id not in conversation_history:
230
  conversation_history[user_id] = []
231
  conversation_history[user_id].append(question)