Yjhhh committed on
Commit
95d2fbc
·
verified ·
1 Parent(s): 9df8ce5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -92
main.py CHANGED
@@ -6,15 +6,12 @@ import redis
6
  from transformers import (
7
  AutoTokenizer,
8
  AutoModelForSequenceClassification,
9
- TrainingArguments,
10
  )
11
  import torch
12
  import torch.nn as nn
13
  from torch.utils.data import DataLoader, Dataset
14
  from torch.optim import AdamW
15
  from fastapi import FastAPI, HTTPException, Request
16
- from pydantic import BaseModel
17
- from typing import List, Dict
18
  from fastapi.responses import HTMLResponse
19
  import multiprocessing
20
  import time
@@ -29,75 +26,21 @@ REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
29
 
30
  app = FastAPI()
31
 
32
- # Configuración de idioma
33
- language_responses = {
34
- "es": {
35
- 0: [
36
- "Lo siento, no entiendo.",
37
- "No estoy seguro de entender lo que quieres decir.",
38
- "¿Podrías reformular tu pregunta?"
39
- ],
40
- 1: [
41
- "Hola! ¿Cómo estás?",
42
- "¡Hola! ¿Qué tal?",
43
- "Buenos días/tardes/noches, ¿cómo te va?"
44
- ],
45
- 2: [
46
- "¿Cómo te puedo ayudar?",
47
- "¿En qué puedo ayudarte?",
48
- "Dime, ¿qué necesitas?"
49
- ],
50
- # ... más respuestas para otras clases
51
- },
52
- "en": {
53
- 0: [
54
- "Sorry, I don't understand.",
55
- "I'm not sure I understand what you mean.",
56
- "Could you rephrase your question?"
57
- ],
58
- 1: [
59
- "Hello! How are you?",
60
- "Hi! What's up?",
61
- "Good morning/afternoon/evening, how are you doing?"
62
- ],
63
- 2: [
64
- "How can I help you?",
65
- "What can I do for you?",
66
- "Tell me, what do you need?"
67
- ],
68
- # ... más respuestas para otras clases
69
- }
70
- }
71
-
72
- default_language = "es" # Idioma predeterminado
73
-
74
- # Servicio de Chatbot
75
  class ChatbotService:
76
  def get_response(self, user_id, message, predicted_class, language=default_language):
77
- # Accede al diccionario de respuestas según el idioma
78
- responses = language_responses.get(language, language_responses["es"])
79
-
80
- # Lógica basada en la clase predicha
81
- if predicted_class == 1:
82
- # Seleccionar una respuesta de saludo aleatoria
83
- return random.choice(responses[1])
84
- elif predicted_class == 2:
85
- # Seleccionar una respuesta de ayuda aleatoria
86
- return random.choice(responses[2])
87
- else:
88
- # Seleccionar una respuesta de no comprensión aleatoria
89
- return random.choice(responses[0])
90
 
91
  chatbot_service = ChatbotService()
92
 
93
- # Modelo de clasificación de texto
94
  class UnifiedModel(nn.Module):
95
  def __init__(self, models):
96
  super(UnifiedModel, self).__init__()
97
  self.models = nn.ModuleList(models)
98
  hidden_size = self.models[0].config.hidden_size
99
- self.projection = nn.Linear(len(models) * 3, 768) # Dimensión de salida corregida
100
- self.classifier = nn.Linear(hidden_size, 3) # 3 clases
101
 
102
  def forward(self, input_ids, attention_mask):
103
  hidden_states = []
@@ -106,11 +49,11 @@ class UnifiedModel(nn.Module):
106
  input_ids=input_id,
107
  attention_mask=attn_mask
108
  )
109
- hidden_states.append(outputs.logits) # Usar directamente outputs.logits
110
 
111
  concatenated_hidden_states = torch.cat(hidden_states, dim=1)
112
- projected_features = self.projection(concatenated_hidden_states) # Proyectar a hidden_size
113
- logits = self.classifier(projected_features)
114
  return logits
115
 
116
  @staticmethod
@@ -118,14 +61,12 @@ class UnifiedModel(nn.Module):
118
  model_name = "unified_model"
119
  model_data_bytes = redis_client.get(f"model:{model_name}")
120
  if model_data_bytes:
121
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
122
  model.load_state_dict(torch.load(model_data_bytes))
123
  else:
124
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
125
-
126
- return UnifiedModel([model, model]) # Asegurar que se usa una lista de modelos, en este caso 2
127
 
128
- # Dataset para entrenamiento
129
  class SyntheticDataset(Dataset):
130
  def __init__(self, tokenizers, data):
131
  self.tokenizers = tokenizers
@@ -146,14 +87,13 @@ class SyntheticDataset(Dataset):
146
  tokenized["labels"] = torch.tensor(label)
147
  return tokenized
148
 
149
- # Manejo de la conversación
150
  conversation_history = {}
151
 
152
  @app.post("/process")
153
  async def process(request: Request):
154
  data = await request.json()
155
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
156
-
157
  tokenizers = {}
158
  models = {}
159
 
@@ -164,10 +104,10 @@ async def process(request: Request):
164
  tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
165
 
166
  if model_data_bytes:
167
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
168
  model.load_state_dict(torch.load(model_data_bytes))
169
  else:
170
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
171
  models[model_name] = model
172
 
173
  if tokenizer_data_bytes:
@@ -179,53 +119,49 @@ async def process(request: Request):
179
 
180
  unified_model = UnifiedModel(list(models.values()))
181
  unified_model.to(torch.device("cpu"))
182
-
183
  if data.get("train"):
184
  user_data = data.get("user_data", [])
185
  if not user_data:
186
  user_data = [
187
- {"text": "Hola", "label": 1},
188
  {"text": "Necesito ayuda", "label": 2},
189
  {"text": "No entiendo", "label": 0}
190
- # ... más ejemplos para otras clases
191
  ]
192
-
193
  redis_client.rpush("training_queue", json.dumps({
194
  "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
195
  "data": user_data
196
  }))
197
 
198
  return {"message": "Training data received. Model will be updated asynchronously."}
199
-
200
  elif data.get("message"):
201
  user_id = data.get("user_id")
202
  text = data['message']
203
  language = data.get("language", default_language)
204
 
205
- # Memoria de Conversación
206
  if user_id not in conversation_history:
207
  conversation_history[user_id] = []
208
  conversation_history[user_id].append(text)
209
 
210
- # Concatenar el historial al mensaje actual (puedes usar otra técnica)
211
- contextualized_text = " ".join(conversation_history[user_id][-3:]) # Usar los últimos 3 mensajes
212
 
213
  tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
214
  input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
215
  attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
216
-
217
  with torch.no_grad():
218
  logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
219
  predicted_class = torch.argmax(logits, dim=-1).item()
220
-
221
  response = get_chatbot_response(user_id, text, predicted_class, language)
222
  return {"answer": response}
223
-
224
  else:
225
  raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
226
 
227
  def get_chatbot_response(user_id, question, predicted_class, language):
228
- # Almacenar el mensaje en el historial
229
  if user_id not in conversation_history:
230
  conversation_history[user_id] = []
231
  conversation_history[user_id].append(question)
@@ -401,7 +337,7 @@ def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
401
 
402
  def continuous_training():
403
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
404
-
405
  while True:
406
  try:
407
  data = redis_client.lpop("training_queue")
@@ -409,12 +345,12 @@ def continuous_training():
409
  data = json.loads(data)
410
  unified_model = UnifiedModel.load_model_from_redis(redis_client)
411
  unified_model.train()
412
-
413
  train_dataset = SyntheticDataset(data["tokenizers"], data["data"])
414
  train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
415
-
416
  optimizer = AdamW(unified_model.parameters(), lr=5e-5)
417
-
418
  for epoch in range(10):
419
  for batch in train_loader:
420
  input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in data["tokenizers"].keys()]
@@ -425,9 +361,9 @@ def continuous_training():
425
  loss.backward()
426
  optimizer.step()
427
  optimizer.zero_grad()
428
-
429
  print(f"Epoch {epoch}, Loss {loss.item()}")
430
-
431
  push_to_redis(unified_model.models, data["tokenizers"], redis_client, "unified_model", "unified_tokenizer")
432
  time.sleep(10)
433
  except Exception as e:
 
6
  from transformers import (
7
  AutoTokenizer,
8
  AutoModelForSequenceClassification,
 
9
  )
10
  import torch
11
  import torch.nn as nn
12
  from torch.utils.data import DataLoader, Dataset
13
  from torch.optim import AdamW
14
  from fastapi import FastAPI, HTTPException, Request
 
 
15
  from fastapi.responses import HTMLResponse
16
  import multiprocessing
17
  import time
 
26
 
27
  app = FastAPI()
28
 
29
+ default_language = "es"
30
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class ChatbotService:
32
  def get_response(self, user_id, message, predicted_class, language=default_language):
33
+ return "Respuesta por defecto."
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  chatbot_service = ChatbotService()
36
 
 
37
  class UnifiedModel(nn.Module):
38
  def __init__(self, models):
39
  super(UnifiedModel, self).__init__()
40
  self.models = nn.ModuleList(models)
41
  hidden_size = self.models[0].config.hidden_size
42
+ self.projection = nn.Linear(len(models) * 3, 768)
43
+ self.classifier = nn.Linear(hidden_size, 3)
44
 
45
  def forward(self, input_ids, attention_mask):
46
  hidden_states = []
 
49
  input_ids=input_id,
50
  attention_mask=attn_mask
51
  )
52
+ hidden_states.append(outputs.logits)
53
 
54
  concatenated_hidden_states = torch.cat(hidden_states, dim=1)
55
+ projected_features = self.projection(concatenated_hidden_states)
56
+ logits = self.classifier(projected_features)
57
  return logits
58
 
59
  @staticmethod
 
61
  model_name = "unified_model"
62
  model_data_bytes = redis_client.get(f"model:{model_name}")
63
  if model_data_bytes:
64
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
65
  model.load_state_dict(torch.load(model_data_bytes))
66
  else:
67
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
68
+ return UnifiedModel([model, model])
 
69
 
 
70
  class SyntheticDataset(Dataset):
71
  def __init__(self, tokenizers, data):
72
  self.tokenizers = tokenizers
 
87
  tokenized["labels"] = torch.tensor(label)
88
  return tokenized
89
 
 
90
  conversation_history = {}
91
 
92
  @app.post("/process")
93
  async def process(request: Request):
94
  data = await request.json()
95
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
96
+
97
  tokenizers = {}
98
  models = {}
99
 
 
104
  tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
105
 
106
  if model_data_bytes:
107
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
108
  model.load_state_dict(torch.load(model_data_bytes))
109
  else:
110
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
111
  models[model_name] = model
112
 
113
  if tokenizer_data_bytes:
 
119
 
120
  unified_model = UnifiedModel(list(models.values()))
121
  unified_model.to(torch.device("cpu"))
122
+
123
  if data.get("train"):
124
  user_data = data.get("user_data", [])
125
  if not user_data:
126
  user_data = [
127
+ {"text": "Hola", "label": 1},
128
  {"text": "Necesito ayuda", "label": 2},
129
  {"text": "No entiendo", "label": 0}
 
130
  ]
131
+
132
  redis_client.rpush("training_queue", json.dumps({
133
  "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
134
  "data": user_data
135
  }))
136
 
137
  return {"message": "Training data received. Model will be updated asynchronously."}
138
+
139
  elif data.get("message"):
140
  user_id = data.get("user_id")
141
  text = data['message']
142
  language = data.get("language", default_language)
143
 
 
144
  if user_id not in conversation_history:
145
  conversation_history[user_id] = []
146
  conversation_history[user_id].append(text)
147
 
148
+ contextualized_text = " ".join(conversation_history[user_id][-3:])
 
149
 
150
  tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
151
  input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
152
  attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
153
+
154
  with torch.no_grad():
155
  logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
156
  predicted_class = torch.argmax(logits, dim=-1).item()
157
+
158
  response = get_chatbot_response(user_id, text, predicted_class, language)
159
  return {"answer": response}
160
+
161
  else:
162
  raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
163
 
164
  def get_chatbot_response(user_id, question, predicted_class, language):
 
165
  if user_id not in conversation_history:
166
  conversation_history[user_id] = []
167
  conversation_history[user_id].append(question)
 
337
 
338
  def continuous_training():
339
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
340
+
341
  while True:
342
  try:
343
  data = redis_client.lpop("training_queue")
 
345
  data = json.loads(data)
346
  unified_model = UnifiedModel.load_model_from_redis(redis_client)
347
  unified_model.train()
348
+
349
  train_dataset = SyntheticDataset(data["tokenizers"], data["data"])
350
  train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
351
+
352
  optimizer = AdamW(unified_model.parameters(), lr=5e-5)
353
+
354
  for epoch in range(10):
355
  for batch in train_loader:
356
  input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in data["tokenizers"].keys()]
 
361
  loss.backward()
362
  optimizer.step()
363
  optimizer.zero_grad()
364
+
365
  print(f"Epoch {epoch}, Loss {loss.item()}")
366
+
367
  push_to_redis(unified_model.models, data["tokenizers"], redis_client, "unified_model", "unified_tokenizer")
368
  time.sleep(10)
369
  except Exception as e: