Yjhhh committed on
Commit
0b404fc
·
verified ·
1 Parent(s): ac5b7b0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +237 -109
main.py CHANGED
@@ -18,6 +18,8 @@ from typing import List, Dict
18
  from fastapi.responses import HTMLResponse
19
  import multiprocessing
20
  import time
 
 
21
 
22
  load_dotenv()
23
 
@@ -27,12 +29,74 @@ REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
27
 
28
  app = FastAPI()
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class UnifiedModel(nn.Module):
31
  def __init__(self, models):
32
  super(UnifiedModel, self).__init__()
33
  self.models = nn.ModuleList(models)
34
  hidden_size = self.models[0].config.hidden_size
35
- self.classifier = nn.Linear(len(models) * hidden_size, 2)
36
 
37
  def forward(self, input_ids, attention_mask):
38
  hidden_states = []
@@ -51,12 +115,13 @@ class UnifiedModel(nn.Module):
51
  model_name = "unified_model"
52
  model_data_bytes = redis_client.get(f"model:{model_name}")
53
  if model_data_bytes:
54
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
55
  model.load_state_dict(torch.load(model_data_bytes))
56
  else:
57
- model = AutoModelForSequenceClassification.from_pretrained("gpt2")
58
  return UnifiedModel([model])
59
 
 
60
  class SyntheticDataset(Dataset):
61
  def __init__(self, tokenizers, data):
62
  self.tokenizers = tokenizers
@@ -77,22 +142,11 @@ class SyntheticDataset(Dataset):
77
  tokenized["labels"] = torch.tensor(label)
78
  return tokenized
79
 
 
 
 
80
  @app.post("/process")
81
  async def process(request: Request):
82
- """
83
- Processes requests for training and prediction.
84
-
85
- Args:
86
- request (Request): The incoming request object.
87
-
88
- Returns:
89
- dict: A dictionary containing either a message indicating successful
90
- training data submission or the model's prediction.
91
-
92
- Raises:
93
- HTTPException: If the request does not contain 'train' or 'predict'
94
- keys.
95
- """
96
  data = await request.json()
97
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
98
 
@@ -106,10 +160,10 @@ async def process(request: Request):
106
  tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
107
 
108
  if model_data_bytes:
109
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
110
  model.load_state_dict(torch.load(model_data_bytes))
111
  else:
112
- model = AutoModelForSequenceClassification.from_pretrained("gpt2")
113
  models[model_name] = model
114
 
115
  if tokenizer_data_bytes:
@@ -125,9 +179,13 @@ async def process(request: Request):
125
  if data.get("train"):
126
  user_data = data.get("user_data", [])
127
  if not user_data:
128
- user_data = [{"text": "Sample text for automatic training.", "label": 0}]
 
 
 
 
 
129
 
130
- # Add user data to Redis queue for asynchronous training
131
  redis_client.rpush("training_queue", json.dumps({
132
  "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
133
  "data": user_data
@@ -135,9 +193,20 @@ async def process(request: Request):
135
 
136
  return {"message": "Training data received. Model will be updated asynchronously."}
137
 
138
- elif data.get("predict"):
139
- text = data['text']
140
- tokenized_inputs = [tokenizers[name](text, return_tensors="pt") for name in tokenizers.keys()]
 
 
 
 
 
 
 
 
 
 
 
141
  input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
142
  attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
143
 
@@ -145,114 +214,178 @@ async def process(request: Request):
145
  logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
146
  predicted_class = torch.argmax(logits, dim=-1).item()
147
 
148
- return {"prediction": predicted_class}
 
149
 
150
  else:
151
- raise HTTPException(status_code=400, detail="Request must contain 'train' or 'predict'.")
152
-
153
- @app.post("/external_answer")
154
- async def external_answer(request: Request):
155
- """
156
- Provides an answer to a question using the unified model and triggers
157
- asynchronous training with the new question-answer pair.
158
 
159
- Args:
160
- request (Request): The incoming request object containing the question.
 
 
 
161
 
162
- Returns:
163
- dict: A dictionary containing the answer to the question.
164
-
165
- Raises:
166
- HTTPException: If the request does not contain a 'question' key.
167
- """
168
- data = await request.json()
169
- redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
170
-
171
- question = data.get('question')
172
- if not question:
173
- raise HTTPException(status_code=400, detail="Question is required.")
174
-
175
- # Load the model and tokenizer from Redis
176
- unified_model = UnifiedModel.load_model_from_redis(redis_client)
177
- unified_model.to(torch.device("cpu"))
178
-
179
- tokenizer_data_bytes = redis_client.get(f"tokenizer:unified_tokenizer")
180
- if tokenizer_data_bytes:
181
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
182
- tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
183
- else:
184
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
185
-
186
- tokenized_input = tokenizer(question, return_tensors="pt")
187
- input_ids = tokenized_input['input_ids']
188
- attention_mask = tokenized_input['attention_mask']
189
-
190
- with torch.no_grad():
191
- logits = unified_model(input_ids=[input_ids], attention_mask=[attention_mask])
192
- predicted_class = torch.argmax(logits, dim=-1).item()
193
- response = {"answer": f"Response to '{question}' is class {predicted_class}"}
194
-
195
- # Asynchronously train on the new data point
196
- redis_client.rpush("training_queue", json.dumps({
197
- "tokenizers": {"unified_tokenizer": tokenizer.get_vocab()},
198
- "data": [{"text": question, "label": predicted_class}]
199
- }))
200
-
201
- return response
202
 
203
  @app.get("/")
204
  async def get_home():
205
- """
206
- Serves a basic HTML page as the home route.
207
-
208
- Returns:
209
- HTMLResponse: The HTML content of the home page.
210
- """
211
- html_code = """
212
  <!DOCTYPE html>
213
  <html>
214
  <head>
215
  <meta charset="UTF-8">
216
  <title>Chatbot</title>
217
  <style>
218
- body {
219
- font-family: Arial, sans-serif;
220
  background-color: #f4f4f9;
221
  margin: 0;
222
  padding: 0;
223
- }
224
- .container {
225
- max-width: 1200px;
226
- margin: 0 auto;
227
- padding: 20px;
228
- }
229
- h1 {
 
 
 
 
 
 
 
 
 
230
  color: #333;
231
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  </style>
233
  </head>
234
  <body>
235
  <div class="container">
236
- <h1>Chatbot Interface</h1>
 
 
 
 
 
 
237
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  </body>
239
  </html>
240
  """
241
  return HTMLResponse(content=html_code)
242
 
243
  def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
244
- """
245
- Saves the given models and tokenizers to Redis.
246
-
247
- Args:
248
- models (dict): A dictionary of model names and their corresponding
249
- PyTorch models.
250
- tokenizers (dict): A dictionary of tokenizer names and their
251
- corresponding tokenizers.
252
- redis_client: The Redis client instance.
253
- model_name (str): The base name to use for saving the models.
254
- tokenizer_name (str): The base name to use for saving the tokenizers.
255
- """
256
  for model_name, model in models.items():
257
  torch.save(model.state_dict(), model_name)
258
  redis_client.set(f"model:{model_name}", open(model_name, "rb").read())
@@ -262,9 +395,6 @@ def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
262
  redis_client.set(f"tokenizer:{tokenizer_name}", json.dumps(tokens))
263
 
264
  def continuous_training():
265
- """
266
- Continuously checks for new training data in Redis and updates the model.
267
- """
268
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
269
 
270
  while True:
@@ -300,10 +430,8 @@ def continuous_training():
300
  time.sleep(5)
301
 
302
  if __name__ == "__main__":
303
- # Start the continuous training process in a separate process
304
  training_process = multiprocessing.Process(target=continuous_training)
305
  training_process.start()
306
 
307
- # Run the FastAPI app
308
  import uvicorn
309
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
18
  from fastapi.responses import HTMLResponse
19
  import multiprocessing
20
  import time
21
+ import uuid
22
+ import random
23
 
24
  load_dotenv()
25
 
 
29
 
30
  app = FastAPI()
31
 
32
# Language configuration: canned replies keyed first by language code and
# then by predicted class id. get_response() below uses class 1 for
# greetings, class 2 for offers of help, and class 0 for everything else.
language_responses = {
    "es": {
        0: [
            "Lo siento, no entiendo.",
            "No estoy seguro de entender lo que quieres decir.",
            "¿Podrías reformular tu pregunta?"
        ],
        1: [
            "Hola! ¿Cómo estás?",
            "¡Hola! ¿Qué tal?",
            "Buenos días/tardes/noches, ¿cómo te va?"
        ],
        2: [
            "¿Cómo te puedo ayudar?",
            "¿En qué puedo ayudarte?",
            "Dime, ¿qué necesitas?"
        ],
        # ... more replies for other classes
    },
    "en": {
        0: [
            "Sorry, I don't understand.",
            "I'm not sure I understand what you mean.",
            "Could you rephrase your question?"
        ],
        1: [
            "Hello! How are you?",
            "Hi! What's up?",
            "Good morning/afternoon/evening, how are you doing?"
        ],
        2: [
            "How can I help you?",
            "What can I do for you?",
            "Tell me, what do you need?"
        ],
        # ... more replies for other classes
    }
}

default_language = "es"  # Default language


# Chatbot service
class ChatbotService:
    """Pick a canned reply for a predicted message class."""

    def get_response(self, user_id, message, predicted_class, language=default_language):
        """Return a random reply for `predicted_class` in `language`.

        Args:
            user_id: Identifier of the requesting user (not used for the
                lookup; kept for interface compatibility).
            message: The user's message (not used for the lookup; kept for
                interface compatibility).
            predicted_class: Class id produced by the classifier.
            language: Key into `language_responses`.

        Returns:
            One reply string chosen at random from the matching list.
        """
        # `language` arrives straight from the request body, so it may be
        # None or an unhashable JSON value; only attempt the lookup for
        # strings and fall back to the default language otherwise.
        replies_by_class = None
        if isinstance(language, str):
            replies_by_class = language_responses.get(language)
        if replies_by_class is None:
            replies_by_class = language_responses[default_language]

        # Table lookup instead of an if/elif chain, so replies added for new
        # classes are picked up automatically. Unknown or empty classes use
        # the "not understood" replies (class 0), as before.
        candidates = replies_by_class.get(predicted_class) or replies_by_class[0]
        return random.choice(candidates)


chatbot_service = ChatbotService()
92
+
93
+ # Modelo de clasificación de texto
94
  class UnifiedModel(nn.Module):
95
  def __init__(self, models):
96
  super(UnifiedModel, self).__init__()
97
  self.models = nn.ModuleList(models)
98
  hidden_size = self.models[0].config.hidden_size
99
+ self.classifier = nn.Linear(len(models) * hidden_size, 3) # 3 clases
100
 
101
  def forward(self, input_ids, attention_mask):
102
  hidden_states = []
 
115
  model_name = "unified_model"
116
  model_data_bytes = redis_client.get(f"model:{model_name}")
117
  if model_data_bytes:
118
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
119
  model.load_state_dict(torch.load(model_data_bytes))
120
  else:
121
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
122
  return UnifiedModel([model])
123
 
124
+ # Dataset para entrenamiento
125
  class SyntheticDataset(Dataset):
126
  def __init__(self, tokenizers, data):
127
  self.tokenizers = tokenizers
 
142
  tokenized["labels"] = torch.tensor(label)
143
  return tokenized
144
 
145
+ # Manejo de la conversación
146
+ conversation_history = {}
147
+
148
  @app.post("/process")
149
  async def process(request: Request):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  data = await request.json()
151
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
152
 
 
160
  tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
161
 
162
  if model_data_bytes:
163
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
164
  model.load_state_dict(torch.load(model_data_bytes))
165
  else:
166
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
167
  models[model_name] = model
168
 
169
  if tokenizer_data_bytes:
 
179
  if data.get("train"):
180
  user_data = data.get("user_data", [])
181
  if not user_data:
182
+ user_data = [
183
+ {"text": "Hola", "label": 1},
184
+ {"text": "Necesito ayuda", "label": 2},
185
+ {"text": "No entiendo", "label": 0}
186
+ # ... más ejemplos para otras clases
187
+ ]
188
 
 
189
  redis_client.rpush("training_queue", json.dumps({
190
  "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
191
  "data": user_data
 
193
 
194
  return {"message": "Training data received. Model will be updated asynchronously."}
195
 
196
+ elif data.get("message"):
197
+ user_id = data.get("user_id")
198
+ text = data['message']
199
+ language = data.get("language", default_language)
200
+
201
+ # Memoria de Conversación
202
+ if user_id not in conversation_history:
203
+ conversation_history[user_id] = []
204
+ conversation_history[user_id].append(text)
205
+
206
+ # Concatenar el historial al mensaje actual (puedes usar otra técnica)
207
+ contextualized_text = " ".join(conversation_history[user_id][-3:]) # Usar los últimos 3 mensajes
208
+
209
+ tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
210
  input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
211
  attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
212
 
 
214
  logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
215
  predicted_class = torch.argmax(logits, dim=-1).item()
216
 
217
+ response = get_chatbot_response(user_id, text, predicted_class, language)
218
+ return {"answer": response}
219
 
220
  else:
221
+ raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
 
 
 
 
 
 
222
 
223
def get_chatbot_response(user_id, question, predicted_class, language):
    """Record `question` in the user's history and return the bot's reply.

    Args:
        user_id: Key into the module-level `conversation_history` dict.
        question: The user's message.
        predicted_class: Class id predicted for the message.
        language: Language code forwarded to the chatbot service.

    Returns:
        The reply string produced by `chatbot_service.get_response`.
    """
    history = conversation_history.setdefault(user_id, [])

    # The /process handler already appends the incoming message before it
    # calls this function. Appending unconditionally stored every message
    # twice, so the "last 3 messages" context window was filled with
    # duplicates. Only record the message if it is not already the latest
    # entry; direct callers still get their message stored.
    # NOTE(review): a direct caller repeating the same text back-to-back will
    # have it stored once -- confirm that is acceptable.
    if not history or history[-1] != question:
        history.append(question)

    return chatbot_service.get_response(user_id, question, predicted_class, language)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
@app.get("/")
async def get_home():
    """Serve the single-page chat UI.

    A fresh UUID is generated on every page load and embedded in a hidden
    input; the page script sends it as `user_id` with each POST to /process.

    Returns:
        HTMLResponse: The rendered chat page.
    """
    user_id = str(uuid.uuid4())
    # f-string template: literal CSS/JS braces are doubled ({{ }}); the only
    # interpolated value is {user_id}.
    html_code = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Chatbot</title>
        <style>
            body {{
                font-family: 'Arial', sans-serif;
                background-color: #f4f4f9;
                margin: 0;
                padding: 0;
                display: flex;
                align-items: center;
                justify-content: center;
                min-height: 100vh;
            }}

            .container {{
                background-color: #fff;
                border-radius: 10px;
                box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
                overflow: hidden;
                width: 400px;
                max-width: 90%;
            }}

            h1 {{
                color: #333;
                text-align: center;
                padding: 20px;
                margin: 0;
                background-color: #f8f9fa;
                border-bottom: 1px solid #eee;
            }}

            #chatbox {{
                height: 400px;
                padding: 20px;
                overflow-y: auto;
            }}

            .message {{
                margin-bottom: 15px;
                padding: 10px;
                border-radius: 5px;
                max-width: 70%;
                animation: slide-in 0.3s ease-out;
            }}

            .user-message {{
                text-align: right;
                background-color: #eee;
                margin-left: 30%;
            }}

            .bot-message {{
                text-align: left;
                background-color: #ccf5ff;
                margin-right: 30%;
            }}

            #input-area {{
                display: flex;
                padding: 10px;
                background-color: #f8f9fa;
                border-top: 1px solid #eee;
            }}

            #message-input {{
                flex: 1;
                padding: 10px;
                border: 1px solid #ccc;
                border-radius: 5px;
                margin-right: 10px;
            }}

            #send-button {{
                padding: 10px 15px;
                background-color: #28a745;
                color: white;
                border: none;
                cursor: pointer;
                border-radius: 5px;
                transition: background-color 0.3s ease;
            }}

            #send-button:hover {{
                background-color: #218838;
            }}

            @keyframes slide-in {{
                from {{
                    transform: translateX(-100%);
                    opacity: 0;
                }}
                to {{
                    transform: translateX(0);
                    opacity: 1;
                }}
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>Chatbot</h1>
            <div id="chatbox"></div>
            <div id="input-area">
                <input type="hidden" id="user-id" value="{user_id}">
                <input type="text" id="message-input" placeholder="Escribe tu mensaje...">
                <button id="send-button">Enviar</button>
            </div>
        </div>
        <script>
            const chatbox = document.getElementById('chatbox');
            const messageInput = document.getElementById('message-input');
            const sendButton = document.getElementById('send-button');
            const userId = document.getElementById('user-id').value;

            sendButton.addEventListener('click', sendMessage);
            // Let the user send with the Enter key as well as the button.
            messageInput.addEventListener('keydown', (event) => {{
                if (event.key === 'Enter') {{
                    sendMessage();
                }}
            }});

            function sendMessage() {{
                const message = messageInput.value;
                if (message.trim() === '') return;

                appendMessage('user', message);
                messageInput.value = '';

                fetch('/process', {{
                    method: 'POST',
                    headers: {{
                        'Content-Type': 'application/json'
                    }},
                    body: JSON.stringify({{ message: message, user_id: userId, language: 'es' }})
                }})
                .then(response => {{
                    // fetch() resolves on HTTP errors too; without this check
                    // a 4xx/5xx reply would render "undefined" as the answer.
                    if (!response.ok) {{
                        throw new Error('HTTP ' + response.status);
                    }}
                    return response.json();
                }})
                .then(data => {{
                    appendMessage('bot', data.answer);
                }})
                .catch(() => {{
                    // Network failure, non-2xx status, or invalid JSON.
                    appendMessage('bot', 'Lo siento, ocurrió un error. Inténtalo de nuevo.');
                }});
            }}

            function appendMessage(sender, message) {{
                const messageElement = document.createElement('div');
                messageElement.classList.add('message', `${{sender}}-message`);
                messageElement.textContent = message;
                chatbox.appendChild(messageElement);
                chatbox.scrollTop = chatbox.scrollHeight;
            }}
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code)
387
 
388
  def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
 
 
 
 
 
 
 
 
 
 
 
 
389
  for model_name, model in models.items():
390
  torch.save(model.state_dict(), model_name)
391
  redis_client.set(f"model:{model_name}", open(model_name, "rb").read())
 
395
  redis_client.set(f"tokenizer:{tokenizer_name}", json.dumps(tokens))
396
 
397
  def continuous_training():
 
 
 
398
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
399
 
400
  while True:
 
430
  time.sleep(5)
431
 
432
if __name__ == "__main__":
    # Import before spawning the worker so a missing uvicorn fails fast
    # without leaving a training process behind.
    import uvicorn

    # Run the training loop in a separate process so it does not block the
    # web server.
    training_process = multiprocessing.Process(target=continuous_training)
    training_process.start()

    try:
        # Run the FastAPI app (blocks until the server shuts down).
        uvicorn.run(app, host="0.0.0.0", port=7860)
    finally:
        # The training worker is a non-daemon process and appears to loop
        # indefinitely; stop it explicitly so the interpreter does not wait
        # on it after the server has exited.
        training_process.terminate()
        training_process.join()