Yjhhh commited on
Commit
ac27d05
·
verified ·
1 Parent(s): 95d2fbc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +56 -5
main.py CHANGED
@@ -6,6 +6,7 @@ import redis
6
  from transformers import (
7
  AutoTokenizer,
8
  AutoModelForSequenceClassification,
 
9
  )
10
  import torch
11
  import torch.nn as nn
@@ -29,8 +30,46 @@ app = FastAPI()
29
  default_language = "es"
30
 
31
  class ChatbotService:
32
- def get_response(self, user_id, message, predicted_class, language=default_language):
33
- return "Respuesta por defecto."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  chatbot_service = ChatbotService()
36
 
@@ -155,7 +194,13 @@ async def process(request: Request):
155
  logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
156
  predicted_class = torch.argmax(logits, dim=-1).item()
157
 
158
- response = get_chatbot_response(user_id, text, predicted_class, language)
 
 
 
 
 
 
159
  return {"answer": response}
160
 
161
  else:
@@ -166,7 +211,7 @@ def get_chatbot_response(user_id, question, predicted_class, language):
166
  conversation_history[user_id] = []
167
  conversation_history[user_id].append(question)
168
 
169
- return chatbot_service.get_response(user_id, question, predicted_class, language)
170
 
171
  @app.get("/")
172
  async def get_home():
@@ -364,7 +409,13 @@ def continuous_training():
364
 
365
  print(f"Epoch {epoch}, Loss {loss.item()}")
366
 
367
- push_to_redis(unified_model.models, data["tokenizers"], redis_client, "unified_model", "unified_tokenizer")
 
 
 
 
 
 
368
  time.sleep(10)
369
  except Exception as e:
370
  print(f"Error in continuous training: {e}")
 
6
  from transformers import (
7
  AutoTokenizer,
8
  AutoModelForSequenceClassification,
9
+ AutoModelForCausalLM,
10
  )
11
  import torch
12
  import torch.nn as nn
 
30
  default_language = "es"
31
 
32
class ChatbotService:
    """Serve chat replies from a causal-LM whose weights live in Redis.

    The continuous-training loop pushes model weights (binary torch
    state_dict bytes under ``model:<name>``) and tokenizer additions
    (JSON under ``tokenizer:<name>``) to Redis; this service pulls them
    on every request so responses always use the latest published model.
    """

    def __init__(self):
        # decode_responses MUST be False: the model entry is a binary
        # torch state_dict blob, and UTF-8 decoding it would corrupt it
        # (or raise UnicodeDecodeError). json.loads() accepts bytes
        # directly, so the tokenizer entry still parses fine.
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST,
            port=REDIS_PORT,
            password=REDIS_PASSWORD,
            decode_responses=False,
        )
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"

    def get_response(self, user_id, message, language=default_language):
        """Generate a reply for *message*, or a Spanish fallback string.

        ``user_id`` and ``language`` are kept for caller compatibility
        but are not currently used in generation.
        """
        model = self.load_model_from_redis()
        tokenizer = self.load_tokenizer_from_redis()

        # Trainer has not published a model/tokenizer yet.
        if model is None or tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."

        input_text = f"Usuario: {message} Asistente:"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cpu")

        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                max_length=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )

        response = tokenizer.decode(output[0], skip_special_tokens=True)
        # Strip the prompt prefix so only the assistant's text remains.
        response = response.replace(input_text, "").strip()

        return response

    def load_model_from_redis(self):
        """Return the latest model from Redis, or None if not yet pushed."""
        model_data_bytes = self.redis_client.get(f"model:{self.model_name}")
        if not model_data_bytes:
            return None
        # Local import so the file's top-level import block is untouched.
        import io
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        # torch.load requires a path or file-like object, not raw bytes —
        # wrap the Redis payload in BytesIO before deserializing.
        model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
        return model

    def load_tokenizer_from_redis(self):
        """Return the tokenizer with stored additions, or None if absent."""
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if not tokenizer_data_bytes:
            return None
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # Payload is JSON; json.loads handles bytes input directly.
        # NOTE(review): the trainer appears to enqueue get_vocab() dicts —
        # add_tokens() expects a list of tokens; confirm the pushed format.
        tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
        return tokenizer
73
 
74
  chatbot_service = ChatbotService()
75
 
 
194
  logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
195
  predicted_class = torch.argmax(logits, dim=-1).item()
196
 
197
+ response = chatbot_service.get_response(user_id, contextualized_text, language)
198
+
199
+ redis_client.rpush("training_queue", json.dumps({
200
+ "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
201
+ "data": [{"text": contextualized_text, "label": predicted_class}]
202
+ }))
203
+
204
  return {"answer": response}
205
 
206
  else:
 
211
  conversation_history[user_id] = []
212
  conversation_history[user_id].append(question)
213
 
214
+ return chatbot_service.get_response(user_id, question, language)
215
 
216
  @app.get("/")
217
  async def get_home():
 
409
 
410
  print(f"Epoch {epoch}, Loss {loss.item()}")
411
 
412
+ push_to_redis(
413
+ {"response_model": unified_model},
414
+ {"response_tokenizer": tokenizer},
415
+ redis_client,
416
+ "response_model",
417
+ "response_tokenizer",
418
+ )
419
  time.sleep(10)
420
  except Exception as e:
421
  print(f"Error in continuous training: {e}")