|
from dotenv import load_dotenv |
|
import os |
|
import json |
|
import redis |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForSequenceClassification, |
|
AutoModelForCausalLM, |
|
TrainingArguments, |
|
Trainer, |
|
AutoModelForTextToWaveform, |
|
pipeline, |
|
) |
|
from diffusers import FluxPipeline |
|
from fastapi import FastAPI, HTTPException, Request |
|
from fastapi.responses import HTMLResponse |
|
import multiprocessing |
|
import uuid |
|
import torch |
|
from torch.utils.data import Dataset |
|
import numpy as np |
|
|
|
# Pull variables from a local .env file (if present) into the process
# environment before any configuration is read.
load_dotenv()

# Redis connection settings. Defaults are supplied so a missing variable does
# not reach the Redis client as None; the port is converted to an int once
# here rather than relying on the client to coerce a string.
REDIS_HOST = os.getenv('REDIS_HOST') or "localhost"
REDIS_PORT = int(os.getenv('REDIS_PORT') or 6379)
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

# ASGI application object; route handlers below register themselves on it.
app = FastAPI()

# Language code used when a request does not specify one.
default_language = "es"
|
|
|
class ChatbotService:
    """Generate chat replies with a GPT-2 model whose weights live in Redis.

    On construction the service connects to Redis and tries to load a model
    state dict (key ``model:response_model``) and extra tokenizer tokens (key
    ``tokenizer:response_tokenizer``). Either may be absent, in which case
    the corresponding attribute is ``None`` and ``get_response`` returns a
    fixed "not ready" message.
    """

    def __init__(self):
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"
        self.model = self.load_model_from_redis()
        self.tokenizer = self.load_tokenizer_from_redis()

    def get_response(self, user_id, message, language=default_language):
        """Return the model's reply to ``message``.

        Args:
            user_id: Caller-supplied identifier. Currently unused.
            message: Text placed into the ``Usuario: ... Asistente:`` prompt.
            language: Language code. Currently unused.

        Returns:
            The generated continuation with the prompt removed, or a fixed
            Spanish "model not ready" message when no model/tokenizer loaded.
        """
        if self.model is None or self.tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."

        max_new_tokens = 100
        input_text = f"Usuario: {message} Asistente:"
        encoded = self.tokenizer(input_text, return_tensors="pt")
        input_ids = encoded["input_ids"].to("cpu")
        attention_mask = encoded["attention_mask"].to("cpu")

        # Keep only the most recent prompt tokens so prompt + reply fit in the
        # model's context window; the tail is kept because it holds the
        # "Asistente:" cue.
        context_size = getattr(self.model.config, "max_position_embeddings", None)
        if context_size is not None:
            prompt_budget = max(1, context_size - max_new_tokens)
            input_ids = input_ids[:, -prompt_budget:]
            attention_mask = attention_mask[:, -prompt_budget:]

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                # max_new_tokens bounds the reply only; max_length would also
                # count the prompt and leave long prompts with no room.
                max_new_tokens=max_new_tokens,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Slice off the prompt by token position: decoding and then
        # string-replacing the prompt is unreliable because decoded text is
        # not guaranteed to match the original prompt character for character.
        reply_ids = output[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(reply_ids, skip_special_tokens=True).strip()

    def load_model_from_redis(self):
        """Load GPT-2 and overlay the state dict stored in Redis.

        Returns:
            The model, or ``None`` when the Redis key is missing or empty.
        """
        import io

        model_data_bytes = self.redis_client.get(f"model:{self.model_name}")
        if not model_data_bytes:
            return None

        model = AutoModelForCausalLM.from_pretrained("gpt2")
        # torch.load expects a path or file-like object, so wrap the raw bytes.
        # SECURITY: torch.load unpickles its input; only load blobs written by
        # a trusted producer (consider weights_only=True where supported).
        # NOTE(review): if the stored weights were trained with a resized
        # vocabulary, load_state_dict will report a shape mismatch - confirm
        # how the blob is produced.
        state_dict = torch.load(io.BytesIO(model_data_bytes), map_location="cpu")
        model.load_state_dict(state_dict)
        return model

    def load_tokenizer_from_redis(self):
        """Load the GPT-2 tokenizer and add the extra tokens stored in Redis.

        Returns:
            The tokenizer, or ``None`` when the Redis key is missing or empty.
        """
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if not tokenizer_data_bytes:
            return None

        extra_tokens = json.loads(tokenizer_data_bytes.decode("utf-8"))
        # Accept a vocab mapping (token -> id) as well as a plain token list.
        if isinstance(extra_tokens, dict):
            extra_tokens = list(extra_tokens)

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.add_tokens(extra_tokens)
        # GPT-2 ships without a pad token; reuse end-of-sequence.
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer
|
|
|
# Module-level singleton used by the /process handler. Constructing it runs
# ChatbotService.__init__, which connects to Redis and attempts to load the
# response model and tokenizer at import time.
chatbot_service = ChatbotService()
|
|
|
class UnifiedModel(AutoModelForSequenceClassification):
    """Sequence-classification model used to label incoming chat text.

    Instances are obtained through ``load_model_from_redis``, which goes
    through the inherited ``from_pretrained`` factory.
    """

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def load_model_from_redis(redis_client):
        """Return the classifier, creating it from GPT-2 on first use.

        Despite its name this method never reads weights from Redis: it
        deletes the ``model:unified_model`` key and then loads the model from
        the local ``models/unified_model`` directory, falling back to a fresh
        3-label GPT-2 classifier (which it saves there) when the directory
        does not exist.

        Args:
            redis_client: Redis client used only to delete the stale key.

        Returns:
            The loaded model with ``config.pad_token_id`` set.
        """
        model_name = "unified_model"
        model_path = f"models/{model_name}"

        # NOTE(review): any copy stored in Redis is discarded here, not
        # loaded - confirm this is intentional. Deleting a missing key is a
        # no-op, so no separate existence check is needed.
        redis_client.delete(f"model:{model_name}")

        is_new = not os.path.exists(model_path)
        if is_new:
            model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
        else:
            model = UnifiedModel.from_pretrained(model_path)

        # GPT-2 has no padding token by default, and its classification head
        # refuses batches larger than one without one. Reuse end-of-sequence,
        # matching how the tokenizer is configured elsewhere in this file.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = model.config.eos_token_id

        if is_new:
            model.save_pretrained(model_path)
        return model
|
|
|
class SyntheticDataset(Dataset):
    """Dataset of ``{"text", "label"}`` records that are tokenized on access."""

    def __init__(self, tokenizer, data):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return model-ready fields.

        The text is padded/truncated to 128 tokens. The tokenizer's
        ``input_ids`` and ``attention_mask`` are squeezed, and the record's
        label is passed through unchanged under the ``labels`` key.
        """
        record = self.data[idx]
        text, label = record['text'], record['label']
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        sample = {field: encoded[field].squeeze() for field in ("input_ids", "attention_mask")}
        sample["labels"] = label
        return sample
|
|
|
# In-memory message history, keyed by the user_id sent to /process.
conversation_history = {}

# Suffix of the Redis key ("tokenizer:<name>") holding extra classifier tokens.
tokenizer_name = "unified_tokenizer"

# Classifier tokenizer and model; both are assigned by startup_event().
tokenizer = None

unified_model = None

# MusicGen tokenizer and model, loaded when this module is imported.
musicgen_tokenizer = AutoTokenizer.from_pretrained("facebook/musicgen-small")

musicgen_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")

# FLUX image pipeline, loaded in bfloat16 when this module is imported, with
# model CPU offload enabled.
image_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)

image_pipeline.enable_model_cpu_offload()
|
|
|
@app.on_event("startup")
async def startup_event():
    """Initialise the module-level classifier tokenizer and model.

    Loads the GPT-2 tokenizer, adds any extra tokens stored in Redis under
    ``tokenizer:<tokenizer_name>``, loads the classifier through
    ``UnifiedModel.load_model_from_redis`` and moves it to the CPU.
    """
    global tokenizer, unified_model
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)

    loaded_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
    if tokenizer_data_bytes:
        extra_tokens = json.loads(tokenizer_data_bytes.decode("utf-8"))
        # Accept a vocab mapping (token -> id) as well as a plain token list.
        if isinstance(extra_tokens, dict):
            extra_tokens = list(extra_tokens)
        loaded_tokenizer.add_tokens(extra_tokens)
    # GPT-2 ships without a pad token; reuse end-of-sequence.
    loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    model = UnifiedModel.load_model_from_redis(redis_client)
    # Added tokens receive ids beyond the base vocabulary; grow the embedding
    # table so those ids do not index past its end.
    if model.get_input_embeddings().num_embeddings < len(loaded_tokenizer):
        model.resize_token_embeddings(len(loaded_tokenizer))
    # The classification head needs a pad id to locate the last real token
    # in padded batches.
    model.config.pad_token_id = loaded_tokenizer.pad_token_id
    model.to(torch.device("cpu"))

    # Publish both globals only once everything is ready.
    tokenizer = loaded_tokenizer
    unified_model = model
|
|
|
@app.post("/process")
async def process(request: Request):
    """Handle a training submission or a chat message.

    The JSON body must contain either a truthy ``train`` field (optionally
    with ``user_data``, a list of ``{"text", "label"}`` records) or a
    ``message`` string (optionally with ``user_id`` and ``language``).

    Returns:
        ``{"message": ...}`` for training submissions, ``{"answer": ...}``
        for chat messages.

    Raises:
        HTTPException: 400 for malformed requests, 503 when the classifier
            has not been initialised yet.
    """
    global tokenizer, unified_model

    try:
        data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError is a ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    wants_training = bool(data.get("train"))
    has_message = bool(data.get("message"))
    if not (wants_training or has_message):
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")

    # Both branches need the tokenizer, which is only set by startup_event().
    if tokenizer is None or unified_model is None:
        raise HTTPException(status_code=503, detail="Models are still loading. Try again shortly.")

    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)

    if wants_training:
        # Fall back to a small built-in sample when no data is supplied.
        user_data = data.get("user_data") or [
            {"text": "Hola", "label": 1},
            {"text": "Necesito ayuda", "label": 2},
            {"text": "No entiendo", "label": 0},
        ]
        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": user_data,
        }))
        return {"message": "Training data received. Model will be updated asynchronously."}

    user_id = data.get("user_id")
    text = data['message']
    language = data.get("language", default_language)
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="'message' must be a string.")
    if isinstance(user_id, (list, dict)):
        # Unhashable values cannot be used as a history key.
        raise HTTPException(status_code=400, detail="'user_id' must be a scalar value.")

    # Only the last three messages are ever used as context, so drop older
    # ones instead of letting every user's history grow without bound.
    history = conversation_history.setdefault(user_id, [])
    history.append(text)
    del history[:-3]
    contextualized_text = " ".join(history)

    # Truncate to the tokenizer's maximum length so long conversations do not
    # overflow the model's position embeddings.
    tokenized_input = tokenizer(contextualized_text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = unified_model(**tokenized_input).logits
    predicted_class = torch.argmax(logits, dim=-1).item()

    response = chatbot_service.get_response(user_id, contextualized_text, language)

    # Queue the classified text as a future training example.
    redis_client.rpush("training_queue", json.dumps({
        "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
        "data": [{"text": contextualized_text, "label": predicted_class}],
    }))
    return {"answer": response}
|
|
|
# Static chat page. Kept as a plain (non f-string) template so CSS/JS braces
# need no escaping; the only dynamic part is the __USER_ID__ placeholder.
_HOME_PAGE_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>Chatbot</title>
    <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f4f4f9;
            margin: 0;
            padding: 0;
            display: flex;
            align-items: center;
            justify-content: center;
            min-height: 100vh;
        }
        .container {
            background-color: #fff;
            border-radius: 10px;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
            overflow: hidden;
            width: 400px;
            max-width: 90%;
        }
        h1 {
            color: #333;
            text-align: center;
            padding: 20px;
            margin: 0;
            background-color: #f8f9fa;
            border-bottom: 1px solid #eee;
        }
        #chatbox {
            height: 300px;
            overflow-y: auto;
            padding: 10px;
            border-bottom: 1px solid #eee;
        }
        .message {
            margin-bottom: 10px;
            padding: 10px;
            border-radius: 5px;
        }
        .message.user {
            background-color: #e1f5fe;
            text-align: right;
        }
        .message.bot {
            background-color: #f1f1f1;
            text-align: left;
        }
        #input {
            display: flex;
            padding: 10px;
        }
        #input textarea {
            flex: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
            margin-right: 10px;
        }
        #input button {
            padding: 10px 20px;
            border: none;
            border-radius: 4px;
            background-color: #007bff;
            color: #fff;
            cursor: pointer;
        }
        #input button:hover {
            background-color: #0056b3;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>Chatbot</h1>
        <div id="chatbox"></div>
        <div id="input">
            <textarea id="message" rows="3" placeholder="Escribe tu mensaje aquí..."></textarea>
            <button id="send">Enviar</button>
        </div>
    </div>
    <script>
        const chatbox = document.getElementById('chatbox');
        const messageInput = document.getElementById('message');
        const sendButton = document.getElementById('send');

        function appendMessage(text, sender) {
            const messageDiv = document.createElement('div');
            messageDiv.classList.add('message', sender);
            messageDiv.textContent = text;
            chatbox.appendChild(messageDiv);
            chatbox.scrollTop = chatbox.scrollHeight;
        }

        async function sendMessage() {
            const message = messageInput.value;
            if (!message.trim()) return;

            appendMessage(message, 'user');
            messageInput.value = '';

            try {
                const response = await fetch('/process', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({
                        message: message,
                        user_id: '__USER_ID__'
                    })
                });
                if (!response.ok) {
                    throw new Error('HTTP ' + response.status);
                }
                const data = await response.json();
                appendMessage(data.answer, 'bot');
            } catch (err) {
                appendMessage('No se pudo obtener una respuesta. Inténtalo de nuevo.', 'bot');
            }
        }

        sendButton.addEventListener('click', sendMessage);
        messageInput.addEventListener('keypress', (e) => {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendMessage();
            }
        });
    </script>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def get_home():
    """Serve the single-page chat UI.

    A fresh random UUID is generated on every page load and embedded in the
    page's script, which sends it as ``user_id`` with each POST to
    ``/process``. The script shows a fallback message when the request fails
    or returns a non-2xx status instead of rendering ``undefined``.
    """
    user_id = str(uuid.uuid4())
    # A UUID string contains only hex digits and dashes, so it is safe to
    # place inside the single-quoted JavaScript literal.
    html_code = _HOME_PAGE_TEMPLATE.replace("__USER_ID__", user_id)
    return HTMLResponse(content=html_code)
|
|
|
def train_unified_model():
    """Consume ``training_queue`` forever, fine-tuning the classifier.

    Each queue item is a JSON object with a ``tokenizers`` mapping (tokenizer
    name -> vocab) and a ``data`` list of ``{"text", "label"}`` records. For
    every item the classifier is trained for three epochs and saved to
    ``models/unified_model``.

    This function never returns. It is meant to run in its own worker; when
    the module globals have not been initialised in that worker (they are
    only set by ``startup_event`` in the server process) they are loaded here.
    """
    global tokenizer, unified_model
    import logging

    log = logging.getLogger(__name__)
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)

    # A separate process does not inherit state set by the server's startup
    # hook, so initialise lazily instead of dereferencing None.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
    if unified_model is None:
        unified_model = UnifiedModel.load_model_from_redis(redis_client)
    # Batches of 8 require a pad id on the classification head.
    if unified_model.config.pad_token_id is None:
        unified_model.config.pad_token_id = tokenizer.pad_token_id

    model_name = "unified_model"
    model_path = f"models/{model_name}"
    # Loop-invariant, so built once.
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=8,
        num_train_epochs=3,
    )

    while True:
        # Block on the queue with a timeout instead of spinning on LPOP,
        # which would peg a CPU core and flood Redis while the queue is empty.
        popped = redis_client.blpop("training_queue", timeout=5)
        if popped is None:
            continue
        training_data = popped[1]

        try:
            item_data = json.loads(training_data)
            tokenizer_data = item_data["tokenizers"]
            queued_tokenizer_name = next(iter(tokenizer_data))
            if redis_client.exists(f"tokenizer:{queued_tokenizer_name}"):
                tokenizer.add_tokens(list(tokenizer_data[queued_tokenizer_name]))
            # Keep the embedding table in step with any tokens just added.
            if unified_model.get_input_embeddings().num_embeddings < len(tokenizer):
                unified_model.resize_token_embeddings(len(tokenizer))

            data = item_data["data"]
            if not data:
                # Nothing to train on.
                continue
            dataset = SyntheticDataset(tokenizer, data)

            trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
            trainer.train()
            unified_model.save_pretrained(model_path)
        except Exception:
            # Worker boundary: one malformed item must not stop the consumer.
            log.exception("Failed to process an item from training_queue; skipping it")
|
|
|
async def auto_learn():
    """Run the ``training_queue`` consumer without blocking the event loop.

    The work is identical to ``train_unified_model`` (a blocking, never-ending
    loop), so this coroutine delegates to it in an executor thread instead of
    duplicating the loop inline, where it would never yield control and would
    stall every other task on the loop.

    This coroutine only completes if ``train_unified_model`` returns or raises.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, train_unified_model)
|
|
|
async def auto_learn_music():
    """Consume ``music_training_queue`` forever, stepping the MusicGen model.

    Each queue item is UTF-8 JSON that is passed to the MusicGen tokenizer.
    Items whose tokenized form carries no ``labels`` entry are skipped with an
    error log, because the loss below cannot be computed without targets.

    This coroutine never returns; it sleeps briefly whenever the queue is
    empty so it yields to the event loop instead of spinning.
    """
    global musicgen_tokenizer, musicgen_model
    import asyncio
    import logging

    log = logging.getLogger(__name__)
    idle_seconds = 1.0
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)

    # Built once so Adam's running moment estimates persist across items
    # instead of being reset for every queue entry.
    optimizer = torch.optim.Adam(musicgen_model.parameters(), lr=5e-5)
    loss_fn = torch.nn.CrossEntropyLoss()

    while True:
        music_training_data = redis_client.lpop("music_training_queue")
        if not music_training_data:
            await asyncio.sleep(idle_seconds)
            continue

        try:
            music_training_data = json.loads(music_training_data.decode("utf-8"))
            inputs = musicgen_tokenizer(music_training_data, return_tensors="pt", padding=True)

            # NOTE(review): the tokenizer call above is given text only, so
            # nothing in this function ever adds a "labels" entry. Confirm
            # where training targets are supposed to come from.
            if "labels" not in inputs:
                log.error("music_training_queue item has no 'labels'; skipping it")
                continue

            musicgen_model.train()
            # NOTE(review): assumes outputs.logits and inputs['labels'] have
            # shapes CrossEntropyLoss accepts - TODO confirm.
            outputs = musicgen_model(**inputs)
            loss = loss_fn(outputs.logits, inputs['labels'])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        except Exception:
            # Worker boundary: one bad item must not stop the consumer.
            log.exception("Failed to process an item from music_training_queue; skipping it")
|
|
|
async def auto_learn_images():
    """Consume ``image_training_queue`` forever, stepping the image model.

    Each queue item is UTF-8 JSON holding one prompt or a list of prompts.
    For every prompt an image is generated with a fixed seed, and one
    optimisation step is taken against an all-zeros target of the same shape.

    This coroutine never returns; it sleeps briefly whenever the queue is
    empty so it yields to the event loop instead of spinning.
    """
    global image_pipeline
    import asyncio
    import logging

    log = logging.getLogger(__name__)
    idle_seconds = 1.0
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)

    loss_fn = torch.nn.MSELoss()
    # Created on first use and then reused, so optimizer state persists
    # across prompts instead of being reset each time.
    optimizer = None

    while True:
        image_training_data = redis_client.lpop("image_training_queue")
        if not image_training_data:
            await asyncio.sleep(idle_seconds)
            continue

        try:
            image_training_data = json.loads(image_training_data.decode("utf-8"))
            # A bare string would otherwise be iterated character by character.
            if isinstance(image_training_data, str):
                image_training_data = [image_training_data]

            for image_prompt in image_training_data:
                image = image_pipeline(
                    image_prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0),
                ).images[0]
                image_tensor = torch.tensor(np.array(image)).unsqueeze(0)

                # NOTE(review): confirm the pipeline object exposes a
                # trainable `.model` attribute that accepts a raw image
                # tensor, and that the tensor dtype produced above supports
                # MSELoss/backward - TODO verify against the diffusers API.
                image_pipeline.model.train()
                if optimizer is None:
                    optimizer = torch.optim.Adam(image_pipeline.model.parameters(), lr=1e-5)

                target_tensor = torch.zeros_like(image_tensor)
                outputs = image_pipeline.model(image_tensor)
                loss = loss_fn(outputs, target_tensor)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        except Exception:
            # Worker boundary: one bad item must not stop the consumer.
            log.exception("Failed to process an item from image_training_queue; skipping it")
|
|
|
|
|
def _run_async_worker(coro_fn):
    """Run coroutine function ``coro_fn`` to completion in a new event loop.

    ``multiprocessing.Process`` simply calls its target; calling an
    ``async def`` function only creates a coroutine object that is never
    awaited, so async workers must be driven through this wrapper.
    """
    import asyncio

    asyncio.run(coro_fn())


if __name__ == "__main__":
    import uvicorn

    # daemon=True so the background workers are terminated together with the
    # server process instead of lingering after it exits.
    training_process = multiprocessing.Process(target=train_unified_model, daemon=True)
    training_process.start()

    music_training_process = multiprocessing.Process(
        target=_run_async_worker, args=(auto_learn_music,), daemon=True
    )
    music_training_process.start()

    image_training_process = multiprocessing.Process(
        target=_run_async_worker, args=(auto_learn_images,), daemon=True
    )
    image_training_process.start()

    uvicorn.run(app, host="0.0.0.0", port=7860)