Yjhhh committed on
Commit
bde7af1
·
verified ·
1 Parent(s): 2c63956

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -22
main.py CHANGED
@@ -39,16 +39,12 @@ class ChatbotService:
39
  def get_response(self, user_id, message, language=default_language):
40
  if self.model is None or self.tokenizer is None:
41
  return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
42
-
43
  input_text = f"Usuario: {message} Asistente:"
44
  input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cpu")
45
-
46
  with torch.no_grad():
47
  output = self.model.generate(input_ids=input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
48
-
49
  response = self.tokenizer.decode(output[0], skip_special_tokens=True)
50
  response = response.replace(input_text, "").strip()
51
-
52
  return response
53
 
54
  def load_model_from_redis(self):
@@ -64,6 +60,7 @@ class ChatbotService:
64
  if tokenizer_data_bytes:
65
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
66
  tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
 
67
  return tokenizer
68
  return None
69
 
@@ -82,7 +79,6 @@ class UnifiedModel(nn.Module):
82
  for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
83
  outputs = model(input_ids=input_id, attention_mask=attn_mask)
84
  hidden_states.append(outputs.logits)
85
-
86
  concatenated_hidden_states = torch.cat(hidden_states, dim=1)
87
  projected_features = self.projection(concatenated_hidden_states)
88
  logits = self.classifier(projected_features)
@@ -145,8 +141,10 @@ async def process(request: Request):
145
  if tokenizer_data_bytes:
146
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
147
  tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
 
148
  else:
149
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
150
  tokenizers[tokenizer_name] = tokenizer
151
 
152
  unified_model = UnifiedModel(list(models.values()))
@@ -160,42 +158,31 @@ async def process(request: Request):
160
  {"text": "Necesito ayuda", "label": 2},
161
  {"text": "No entiendo", "label": 0}
162
  ]
163
-
164
  redis_client.rpush("training_queue", json.dumps({
165
  "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
166
  "data": user_data
167
  }))
168
-
169
  return {"message": "Training data received. Model will be updated asynchronously."}
170
-
171
  elif data.get("message"):
172
  user_id = data.get("user_id")
173
  text = data['message']
174
  language = data.get("language", default_language)
175
-
176
  if user_id not in conversation_history:
177
  conversation_history[user_id] = []
178
  conversation_history[user_id].append(text)
179
-
180
  contextualized_text = " ".join(conversation_history[user_id][-3:])
181
-
182
  tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
183
  input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
184
  attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
185
-
186
  with torch.no_grad():
187
  logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
188
  predicted_class = torch.argmax(logits, dim=-1).item()
189
-
190
  response = chatbot_service.get_response(user_id, contextualized_text, language)
191
-
192
  redis_client.rpush("training_queue", json.dumps({
193
  "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
194
  "data": [{"text": contextualized_text, "label": predicted_class}]
195
  }))
196
-
197
  return {"answer": response}
198
-
199
  else:
200
  raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
201
 
@@ -249,12 +236,15 @@ async def get_home():
249
  }}
250
  .message {{
251
  margin-bottom: 10px;
 
 
252
  }}
253
- .user {{
254
- color: #007bff;
 
255
  }}
256
- .bot {{
257
- color: #28a745;
258
  }}
259
  #input {{
260
  display: flex;
@@ -337,13 +327,15 @@ async def get_home():
337
  return HTMLResponse(content=html_code)
338
 
339
  def train_unified_model():
 
340
  while True:
341
- redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
342
  training_queue = redis_client.lrange("training_queue", 0, -1)
343
  if training_queue:
344
  for item in training_queue:
345
  item_data = json.loads(item)
346
  tokenizers = {name: AutoTokenizer.from_pretrained("gpt2") for name in item_data["tokenizers"]}
 
 
347
  data = item_data["data"]
348
  dataset = SyntheticDataset(tokenizers, data)
349
  dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
@@ -375,4 +367,4 @@ if __name__ == "__main__":
375
  training_process = multiprocessing.Process(target=train_unified_model)
376
  training_process.start()
377
  import uvicorn
378
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
39
  def get_response(self, user_id, message, language=default_language):
40
  if self.model is None or self.tokenizer is None:
41
  return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
 
42
  input_text = f"Usuario: {message} Asistente:"
43
  input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cpu")
 
44
  with torch.no_grad():
45
  output = self.model.generate(input_ids=input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
 
46
  response = self.tokenizer.decode(output[0], skip_special_tokens=True)
47
  response = response.replace(input_text, "").strip()
 
48
  return response
49
 
50
  def load_model_from_redis(self):
 
60
  if tokenizer_data_bytes:
61
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
62
  tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
63
+ tokenizer.pad_token = tokenizer.eos_token
64
  return tokenizer
65
  return None
66
 
 
79
  for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
80
  outputs = model(input_ids=input_id, attention_mask=attn_mask)
81
  hidden_states.append(outputs.logits)
 
82
  concatenated_hidden_states = torch.cat(hidden_states, dim=1)
83
  projected_features = self.projection(concatenated_hidden_states)
84
  logits = self.classifier(projected_features)
 
141
  if tokenizer_data_bytes:
142
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
143
  tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
144
+ tokenizer.pad_token = tokenizer.eos_token
145
  else:
146
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
147
+ tokenizer.pad_token = tokenizer.eos_token
148
  tokenizers[tokenizer_name] = tokenizer
149
 
150
  unified_model = UnifiedModel(list(models.values()))
 
158
  {"text": "Necesito ayuda", "label": 2},
159
  {"text": "No entiendo", "label": 0}
160
  ]
 
161
  redis_client.rpush("training_queue", json.dumps({
162
  "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
163
  "data": user_data
164
  }))
 
165
  return {"message": "Training data received. Model will be updated asynchronously."}
 
166
  elif data.get("message"):
167
  user_id = data.get("user_id")
168
  text = data['message']
169
  language = data.get("language", default_language)
 
170
  if user_id not in conversation_history:
171
  conversation_history[user_id] = []
172
  conversation_history[user_id].append(text)
 
173
  contextualized_text = " ".join(conversation_history[user_id][-3:])
 
174
  tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
175
  input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
176
  attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
 
177
  with torch.no_grad():
178
  logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
179
  predicted_class = torch.argmax(logits, dim=-1).item()
 
180
  response = chatbot_service.get_response(user_id, contextualized_text, language)
 
181
  redis_client.rpush("training_queue", json.dumps({
182
  "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
183
  "data": [{"text": contextualized_text, "label": predicted_class}]
184
  }))
 
185
  return {"answer": response}
 
186
  else:
187
  raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
188
 
 
236
  }}
237
  .message {{
238
  margin-bottom: 10px;
239
+ padding: 10px;
240
+ border-radius: 5px;
241
  }}
242
+ .message.user {{
243
+ background-color: #e1f5fe;
244
+ text-align: right;
245
  }}
246
+ .message.bot {{
247
+ background-color: #f1f1f1;
248
  }}
249
  #input {{
250
  display: flex;
 
327
  return HTMLResponse(content=html_code)
328
 
329
  def train_unified_model():
330
+ redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
331
  while True:
 
332
  training_queue = redis_client.lrange("training_queue", 0, -1)
333
  if training_queue:
334
  for item in training_queue:
335
  item_data = json.loads(item)
336
  tokenizers = {name: AutoTokenizer.from_pretrained("gpt2") for name in item_data["tokenizers"]}
337
+ for tokenizer in tokenizers.values():
338
+ tokenizer.pad_token = tokenizer.eos_token
339
  data = item_data["data"]
340
  dataset = SyntheticDataset(tokenizers, data)
341
  dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
 
367
  training_process = multiprocessing.Process(target=train_unified_model)
368
  training_process.start()
369
  import uvicorn
370
+ uvicorn.run(app, host="0.0.0.0", port=8000)