Update main.py
main.py CHANGED
@@ -18,6 +18,8 @@ from typing import List, Dict
 from fastapi.responses import HTMLResponse
 import multiprocessing
 import time
+import uuid
+import random
 
 load_dotenv()
 
@@ -27,12 +29,74 @@ REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
 
 app = FastAPI()
 
+# Language configuration
+language_responses = {
+    "es": {
+        0: [
+            "Lo siento, no entiendo.",
+            "No estoy seguro de entender lo que quieres decir.",
+            "¿Podrías reformular tu pregunta?"
+        ],
+        1: [
+            "Hola! ¿Cómo estás?",
+            "¡Hola! ¿Qué tal?",
+            "Buenos días/tardes/noches, ¿cómo te va?"
+        ],
+        2: [
+            "¿Cómo te puedo ayudar?",
+            "¿En qué puedo ayudarte?",
+            "Dime, ¿qué necesitas?"
+        ],
+        # ... more responses for the other classes
+    },
+    "en": {
+        0: [
+            "Sorry, I don't understand.",
+            "I'm not sure I understand what you mean.",
+            "Could you rephrase your question?"
+        ],
+        1: [
+            "Hello! How are you?",
+            "Hi! What's up?",
+            "Good morning/afternoon/evening, how are you doing?"
+        ],
+        2: [
+            "How can I help you?",
+            "What can I do for you?",
+            "Tell me, what do you need?"
+        ],
+        # ... more responses for the other classes
+    }
+}
+
+default_language = "es"  # Default language
+
+# Chatbot service
+class ChatbotService:
+    def get_response(self, user_id, message, predicted_class, language=default_language):
+        # Look up the response table for the requested language
+        responses = language_responses.get(language, language_responses["es"])
+
+        # Branch on the predicted class
+        if predicted_class == 1:
+            # Pick a random greeting
+            return random.choice(responses[1])
+        elif predicted_class == 2:
+            # Pick a random help response
+            return random.choice(responses[2])
+        else:
+            # Pick a random "don't understand" response
+            return random.choice(responses[0])
+
+chatbot_service = ChatbotService()
+
+# Text classification model
 class UnifiedModel(nn.Module):
     def __init__(self, models):
         super(UnifiedModel, self).__init__()
         self.models = nn.ModuleList(models)
         hidden_size = self.models[0].config.hidden_size
-        self.classifier = nn.Linear(len(models) * hidden_size,
+        self.classifier = nn.Linear(len(models) * hidden_size, 3)  # 3 classes
 
     def forward(self, input_ids, attention_mask):
         hidden_states = []
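The response-selection logic above is deterministic given the predicted class, so it can be exercised on its own. Below is a quick usage sketch (not part of the commit) that assumes only the definitions from this hunk; note that `user_id` and `message` are accepted but unused by the current implementation.

```python
# Usage sketch for ChatbotService, assuming the definitions above are in scope.
service = ChatbotService()

# Class 1 maps to greetings: the reply is one of the three "es" greetings.
print(service.get_response(user_id="u1", message="Hola", predicted_class=1, language="es"))

# An unknown language falls back to the "es" response table.
print(service.get_response(user_id="u1", message="Hi", predicted_class=2, language="fr"))

# Any class without a dedicated branch (e.g. 0) yields a fallback response.
print(service.get_response(user_id="u1", message="???", predicted_class=0, language="en"))
```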
@@ -51,12 +115,13 @@ class UnifiedModel(nn.Module):
         model_name = "unified_model"
         model_data_bytes = redis_client.get(f"model:{model_name}")
         if model_data_bytes:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)  # 3 classes
             model.load_state_dict(torch.load(model_data_bytes))
         else:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2")
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)  # 3 classes
         return UnifiedModel([model])
 
+# Dataset for training
 class SyntheticDataset(Dataset):
     def __init__(self, tokenizers, data):
         self.tokenizers = tokenizers
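One caveat in the hunk above: `redis_client.get` returns raw bytes (or a decoded string when the client is built with `decode_responses=True`), while `torch.load` expects a file path or a file-like object, so `torch.load(model_data_bytes)` will not deserialize as written. A sketch of a loading path that should work, assuming the state dict was written with `torch.save` and fetched through a binary-safe client (`decode_responses=False`):

```python
import io
import torch

def load_state_dict_from_redis(redis_client, key):
    """Fetch a torch state dict stored as raw bytes under `key`.

    Assumes the value was produced by torch.save(...) and that the Redis
    client was created with decode_responses=False, so binary payloads
    come back as bytes rather than UTF-8 strings.
    """
    payload = redis_client.get(key)
    if payload is None:
        return None
    # torch.load needs a file-like object, so wrap the bytes in BytesIO.
    return torch.load(io.BytesIO(payload))
```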
@@ -77,22 +142,11 @@ class SyntheticDataset(Dataset):
         tokenized["labels"] = torch.tensor(label)
         return tokenized
 
+# Conversation handling
+conversation_history = {}
+
 @app.post("/process")
 async def process(request: Request):
-    """
-    Processes requests for training and prediction.
-
-    Args:
-        request (Request): The incoming request object.
-
-    Returns:
-        dict: A dictionary containing either a message indicating successful
-        training data submission or the model's prediction.
-
-    Raises:
-        HTTPException: If the request does not contain 'train' or 'predict'
-        keys.
-    """
     data = await request.json()
     redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
 
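Note that `conversation_history` is a module-level dict, so it lives in the memory of one process only: it is not shared with the `multiprocessing` trainer, is not shared across uvicorn workers, and resets on restart. Since the app already depends on Redis, one hedged alternative (key names here are illustrative, not part of the codebase) is to keep the rolling history there:

```python
# Sketch: Redis-backed conversation history (hypothetical helper names).
HISTORY_TTL_SECONDS = 3600  # expire idle conversations after an hour

def append_history(redis_client, user_id, text, max_messages=3):
    key = f"history:{user_id}"
    redis_client.rpush(key, text)
    # Keep only the most recent `max_messages` entries.
    redis_client.ltrim(key, -max_messages, -1)
    redis_client.expire(key, HISTORY_TTL_SECONDS)

def get_contextualized_text(redis_client, user_id):
    # Mirrors the " ".join(history[-3:]) logic used in /process below.
    return " ".join(redis_client.lrange(f"history:{user_id}", 0, -1))
```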
@@ -106,10 +160,10 @@ async def process(request: Request):
         tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
 
         if model_data_bytes:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)  # 3 classes
             model.load_state_dict(torch.load(model_data_bytes))
         else:
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2")
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)  # 3 classes
         models[model_name] = model
 
         if tokenizer_data_bytes:
@@ -125,9 +179,13 @@ async def process(request: Request):
     if data.get("train"):
         user_data = data.get("user_data", [])
         if not user_data:
-            user_data = [
+            user_data = [
+                {"text": "Hola", "label": 1},
+                {"text": "Necesito ayuda", "label": 2},
+                {"text": "No entiendo", "label": 0}
+                # ... more examples for the other classes
+            ]
 
-        # Add user data to Redis queue for asynchronous training
         redis_client.rpush("training_queue", json.dumps({
             "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
             "data": user_data
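For reference, a training submission against this endpoint looks like the following sketch (the URL and port match the `uvicorn.run` call at the bottom of the file; the payload shape follows the `data.get("train")` and `data.get("user_data", [])` reads above):

```python
import requests

resp = requests.post(
    "http://localhost:7860/process",
    json={
        "train": True,
        "user_data": [
            {"text": "Buenos días", "label": 1},
            {"text": "¿Me ayudas con mi pedido?", "label": 2},
        ],
    },
)
print(resp.json())  # {"message": "Training data received. ..."}
```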
@@ -135,9 +193,20 @@ async def process(request: Request):
 
         return {"message": "Training data received. Model will be updated asynchronously."}
 
-    elif data.get("predict"):
-
-
+    elif data.get("message"):
+        user_id = data.get("user_id")
+        text = data['message']
+        language = data.get("language", default_language)
+
+        # Conversation memory
+        if user_id not in conversation_history:
+            conversation_history[user_id] = []
+        conversation_history[user_id].append(text)
+
+        # Concatenate the history onto the current message (other techniques would work too)
+        contextualized_text = " ".join(conversation_history[user_id][-3:])  # use the last 3 messages
+
+        tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
         input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
         attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
 
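The chat path added here, which the page's JavaScript also exercises via `fetch('/process', ...)`, can be driven the same way; a sketch under the same localhost assumption:

```python
import requests

resp = requests.post(
    "http://localhost:7860/process",
    json={"message": "Hola", "user_id": "demo-user", "language": "es"},
)
print(resp.json())  # {"answer": "..."}, a class-1 greeting if predicted_class == 1
```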
@@ -145,114 +214,178 @@ async def process(request: Request):
         logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
         predicted_class = torch.argmax(logits, dim=-1).item()
 
-
+        response = get_chatbot_response(user_id, text, predicted_class, language)
+        return {"answer": response}
 
     else:
-        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'predict'.")
+        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
-
-@app.post("/external_answer")
-async def external_answer(request: Request):
-    """
-    Provides an answer to a question using the unified model and triggers
-    asynchronous training with the new question-answer pair.
-
-    Args:
-        request (Request): The incoming request object.
-
-    Returns:
-        dict: A dictionary containing the answer to the question.
-
-    Raises:
-        HTTPException: If the request does not contain a 'question' key.
-    """
-    data = await request.json()
-    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
-
-    question = data.get('question')
-    if not question:
-        raise HTTPException(status_code=400, detail="Question is required.")
-
-    # Load the model and tokenizer from Redis
-    unified_model = UnifiedModel.load_model_from_redis(redis_client)
-    unified_model.to(torch.device("cpu"))
-
-    tokenizer_data_bytes = redis_client.get(f"tokenizer:unified_tokenizer")
-    if tokenizer_data_bytes:
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
-    else:
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
-    tokenized_input = tokenizer(question, return_tensors="pt")
-    input_ids = tokenized_input['input_ids']
-    attention_mask = tokenized_input['attention_mask']
-
-    with torch.no_grad():
-        logits = unified_model(input_ids=[input_ids], attention_mask=[attention_mask])
-        predicted_class = torch.argmax(logits, dim=-1).item()
-        response = {"answer": f"Response to '{question}' is class {predicted_class}"}
-
-    # Asynchronously train on the new data point
-    redis_client.rpush("training_queue", json.dumps({
-        "tokenizers": {"unified_tokenizer": tokenizer.get_vocab()},
-        "data": [{"text": question, "label": predicted_class}]
-    }))
-
-    return response
 
+def get_chatbot_response(user_id, question, predicted_class, language):
+    # Store the message in the history
+    if user_id not in conversation_history:
+        conversation_history[user_id] = []
+    conversation_history[user_id].append(question)
+
+    return chatbot_service.get_response(user_id, question, predicted_class, language)
+
 @app.get("/")
 async def get_home():
-    """
-
-
-    Returns:
-        HTMLResponse: The HTML content of the home page.
-    """
-    html_code = """
+    user_id = str(uuid.uuid4())
+    html_code = f"""
     <!DOCTYPE html>
     <html>
     <head>
         <meta charset="UTF-8">
         <title>Chatbot</title>
         <style>
-            body {
-                font-family: Arial, sans-serif;
+            body {{
+                font-family: 'Arial', sans-serif;
                 background-color: #f4f4f9;
                 margin: 0;
                 padding: 0;
-
-
-
-
-
-
-
+                display: flex;
+                align-items: center;
+                justify-content: center;
+                min-height: 100vh;
+            }}
+
+            .container {{
+                background-color: #fff;
+                border-radius: 10px;
+                box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
+                overflow: hidden;
+                width: 400px;
+                max-width: 90%;
+            }}
+
+            h1 {{
                 color: #333;
-
+                text-align: center;
+                padding: 20px;
+                margin: 0;
+                background-color: #f8f9fa;
+                border-bottom: 1px solid #eee;
+            }}
+
+            #chatbox {{
+                height: 400px;
+                padding: 20px;
+                overflow-y: auto;
+            }}
+
+            .message {{
+                margin-bottom: 15px;
+                padding: 10px;
+                border-radius: 5px;
+                max-width: 70%;
+                animation: slide-in 0.3s ease-out;
+            }}
+
+            .user-message {{
+                text-align: right;
+                background-color: #eee;
+                margin-left: 30%;
+            }}
+
+            .bot-message {{
+                text-align: left;
+                background-color: #ccf5ff;
+                margin-right: 30%;
+            }}
+
+            #input-area {{
+                display: flex;
+                padding: 10px;
+                background-color: #f8f9fa;
+                border-top: 1px solid #eee;
+            }}
+
+            #message-input {{
+                flex: 1;
+                padding: 10px;
+                border: 1px solid #ccc;
+                border-radius: 5px;
+                margin-right: 10px;
+            }}
+
+            #send-button {{
+                padding: 10px 15px;
+                background-color: #28a745;
+                color: white;
+                border: none;
+                cursor: pointer;
+                border-radius: 5px;
+                transition: background-color 0.3s ease;
+            }}
+
+            #send-button:hover {{
+                background-color: #218838;
+            }}
+
+            @keyframes slide-in {{
+                from {{
+                    transform: translateX(-100%);
+                    opacity: 0;
+                }}
+                to {{
+                    transform: translateX(0);
+                    opacity: 1;
+                }}
+            }}
         </style>
     </head>
     <body>
         <div class="container">
-            <h1>Chatbot
+            <h1>Chatbot</h1>
+            <div id="chatbox"></div>
+            <div id="input-area">
+                <input type="hidden" id="user-id" value="{user_id}">
+                <input type="text" id="message-input" placeholder="Escribe tu mensaje...">
+                <button id="send-button">Enviar</button>
+            </div>
         </div>
+        <script>
+            const chatbox = document.getElementById('chatbox');
+            const messageInput = document.getElementById('message-input');
+            const sendButton = document.getElementById('send-button');
+            const userId = document.getElementById('user-id').value;
+
+            sendButton.addEventListener('click', sendMessage);
+
+            function sendMessage() {{
+                const message = messageInput.value;
+                if (message.trim() === '') return;
+
+                appendMessage('user', message);
+                messageInput.value = '';
+
+                fetch('/process', {{
+                    method: 'POST',
+                    headers: {{
+                        'Content-Type': 'application/json'
+                    }},
+                    body: JSON.stringify({{ message: message, user_id: userId, language: 'es' }})
+                }})
+                .then(response => response.json())
+                .then(data => {{
+                    appendMessage('bot', data.answer);
+                }});
+            }}
+
+            function appendMessage(sender, message) {{
+                const messageElement = document.createElement('div');
+                messageElement.classList.add('message', `${{sender}}-message`);
+                messageElement.textContent = message;
+                chatbox.appendChild(messageElement);
+                chatbox.scrollTop = chatbox.scrollHeight;
+            }}
+        </script>
     </body>
     </html>
     """
     return HTMLResponse(content=html_code)
 
 def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
-    """
-    Saves the given models and tokenizers to Redis.
-
-    Args:
-        models (dict): A dictionary of model names and their corresponding
-        PyTorch models.
-        tokenizers (dict): A dictionary of tokenizer names and their
-        corresponding tokenizers.
-        redis_client: The Redis client instance.
-        model_name (str): The base name to use for saving the models.
-        tokenizer_name (str): The base name to use for saving the tokenizers.
-    """
     for model_name, model in models.items():
         torch.save(model.state_dict(), model_name)
         redis_client.set(f"model:{model_name}", open(model_name, "rb").read())
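Note the doubled braces throughout the new CSS and JavaScript: because `html_code` is now an f-string (to interpolate `{user_id}`), every literal `{` and `}` must be written as `{{` and `}}`, including the JS template literal `${{sender}}-message`. A two-line illustration:

```python
user_id = "abc"
print(f"body {{ margin: 0; }} id={user_id}")  # -> body { margin: 0; } id=abc
```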
@@ -262,9 +395,6 @@ def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
         redis_client.set(f"tokenizer:{tokenizer_name}", json.dumps(tokens))
 
 def continuous_training():
-    """
-    Continuously checks for new training data in Redis and updates the model.
-    """
     redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
 
     while True:
@@ -300,10 +430,8 @@ def continuous_training():
         time.sleep(5)
 
 if __name__ == "__main__":
-    # Start the continuous training process in a separate process
     training_process = multiprocessing.Process(target=continuous_training)
     training_process.start()
 
-    # Run the FastAPI app
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
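A final operational note: the trainer process is only spawned under `if __name__ == "__main__":`, so launching the app as `uvicorn main:app` (as many hosts do) would skip `continuous_training` entirely. One way to cover both launch modes is a startup hook; this is a sketch, not part of the commit, and with multiple uvicorn workers it would start one trainer per worker:

```python
# Sketch: start the background trainer regardless of how the app is launched.
# Uses FastAPI's startup event; a lifespan context manager would work as well.
import multiprocessing

@app.on_event("startup")
def start_trainer():
    # daemon=True so the trainer does not block interpreter shutdown.
    process = multiprocessing.Process(target=continuous_training, daemon=True)
    process.start()
```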