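# main.py — FastAPI chat service backed by Redis queues, with background
# worker processes that retrain a GPT-2 classifier from queued examples and
# run rough placeholder update loops for MusicGen and FLUX.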
from dotenv import load_dotenv
import io
import os
import json
import time
import redis
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    AutoModelForTextToWaveform,
    pipeline,
)
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
import multiprocessing
import uuid
import torch
from torch.utils.data import Dataset
import numpy as np

load_dotenv()
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT', 6379))  # redis-py expects an int port; 6379 is the Redis default
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
app = FastAPI()
default_language = "es"
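# Chat response generation: a GPT-2 causal LM whose fine-tuned weights and
# extra tokenizer tokens are optionally pulled from Redis; if neither is
# present, the service replies with a "not ready yet" message.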
class ChatbotService:
    def __init__(self):
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"
        self.model = self.load_model_from_redis()
        self.tokenizer = self.load_tokenizer_from_redis()

    def get_response(self, user_id, message, language=default_language):
        if self.model is None or self.tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
        input_text = f"Usuario: {message} Asistente:"
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cpu")
        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                max_length=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        response = self.tokenizer.decode(output[0], skip_special_tokens=True)
        response = response.replace(input_text, "").strip()
        return response

    def load_model_from_redis(self):
        model_data_bytes = self.redis_client.get(f"model:{self.model_name}")
        if model_data_bytes:
            model = AutoModelForCausalLM.from_pretrained("gpt2")
            # torch.load needs a file-like object, not the raw bytes from Redis.
            model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
            return model
        return None

    def load_tokenizer_from_redis(self):
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if tokenizer_data_bytes:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            tokenizer.add_tokens(json.loads(tokenizer_data_bytes.decode("utf-8")))
            tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        return None
chatbot_service = ChatbotService()
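# Intent classifier used by /process: a GPT-2 sequence-classification head
# persisted under models/ and refreshed by the background training worker.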
class UnifiedModel:
    """Loader for the shared sequence-classification model."""

    @staticmethod
    def load_model_from_redis(redis_client):
        model_name = "unified_model"
        model_path = f"models/{model_name}"
        # Any serialized copy in Redis is considered stale; the model lives on disk.
        if redis_client.exists(f"model:{model_name}"):
            redis_client.delete(f"model:{model_name}")
        if not os.path.exists(model_path):
            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
            # GPT-2 has no pad token by default; classification over padded
            # batches requires one to be defined on the config.
            model.config.pad_token_id = model.config.eos_token_id
            model.save_pretrained(model_path)
        else:
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
        return model
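# Minimal in-memory dataset wrapping {"text": ..., "label": ...} records for
# the Hugging Face Trainer.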
class SyntheticDataset(Dataset):
    def __init__(self, tokenizer, data):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        label = item['label']
        tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
        return {"input_ids": tokens["input_ids"].squeeze(), "attention_mask": tokens["attention_mask"].squeeze(), "labels": label}
conversation_history = {}
tokenizer_name = "unified_tokenizer"
tokenizer = None
unified_model = None
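# Heavy generative models (MusicGen, FLUX) are loaded eagerly at import time.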
musicgen_tokenizer = AutoTokenizer.from_pretrained("facebook/musicgen-small")
musicgen_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")
image_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
image_pipeline.enable_model_cpu_offload()
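# On startup the web process restores the classifier tokenizer (with any extra
# tokens stored in Redis) and the classification model from disk.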
@app.on_event("startup")
async def startup_event():
    global tokenizer, unified_model
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
    if tokenizer_data_bytes:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.add_tokens(json.loads(tokenizer_data_bytes.decode("utf-8")))
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
    unified_model = UnifiedModel.load_model_from_redis(redis_client)
    # Keep the embedding matrix in sync with any tokens added above.
    unified_model.resize_token_embeddings(len(tokenizer))
    unified_model.to(torch.device("cpu"))
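# /process handles two request shapes:
#   {"train": true, "user_data": [{"text": "...", "label": 0}]}  -> enqueue training data
#   {"message": "...", "user_id": "...", "language": "es"}       -> chat turn
# A chat turn is classified, answered by ChatbotService, and fed back into the
# training queue with its predicted label.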
@app.post("/process")
async def process(request: Request):
    global tokenizer, unified_model
    data = await request.json()
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    if data.get("train"):
        user_data = data.get("user_data", [])
        if not user_data:
            user_data = [
                {"text": "Hola", "label": 1},
                {"text": "Necesito ayuda", "label": 2},
                {"text": "No entiendo", "label": 0}
            ]
        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": user_data
        }))
        return {"message": "Training data received. Model will be updated asynchronously."}
    elif data.get("message"):
        user_id = data.get("user_id")
        text = data['message']
        language = data.get("language", default_language)
        if user_id not in conversation_history:
            conversation_history[user_id] = []
        conversation_history[user_id].append(text)
        contextualized_text = " ".join(conversation_history[user_id][-3:])
        tokenized_input = tokenizer(contextualized_text, truncation=True, max_length=128, return_tensors="pt")
        with torch.no_grad():
            logits = unified_model(**tokenized_input).logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        response = chatbot_service.get_response(user_id, contextualized_text, language)
        # Feed the classified exchange back into the training queue (self-labelling).
        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": [{"text": contextualized_text, "label": predicted_class}]
        }))
        return {"answer": response}
    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
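# Minimal single-page chat UI. Doubled braces keep the CSS/JS literal inside
# the f-string while the per-visit user_id is interpolated into the script.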
@app.get("/")
async def get_home():
    user_id = str(uuid.uuid4())
    html_code = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Chatbot</title>
<style>
body {{
font-family: 'Arial', sans-serif;
background-color: #f4f4f9;
margin: 0;
padding: 0;
display: flex;
align-items: center;
justify-content: center;
min-height: 100vh;
}}
.container {{
background-color: #fff;
border-radius: 10px;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
overflow: hidden;
width: 400px;
max-width: 90%;
}}
h1 {{
color: #333;
text-align: center;
padding: 20px;
margin: 0;
background-color: #f8f9fa;
border-bottom: 1px solid #eee;
}}
#chatbox {{
height: 300px;
overflow-y: auto;
padding: 10px;
border-bottom: 1px solid #eee;
}}
.message {{
margin-bottom: 10px;
padding: 10px;
border-radius: 5px;
}}
.message.user {{
background-color: #e1f5fe;
text-align: right;
}}
.message.bot {{
background-color: #f1f1f1;
text-align: left;
}}
#input {{
display: flex;
padding: 10px;
}}
#input textarea {{
flex: 1;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
margin-right: 10px;
}}
#input button {{
padding: 10px 20px;
border: none;
border-radius: 4px;
background-color: #007bff;
color: #fff;
cursor: pointer;
}}
#input button:hover {{
background-color: #0056b3;
}}
</style>
</head>
<body>
<div class="container">
<h1>Chatbot</h1>
<div id="chatbox"></div>
<div id="input">
<textarea id="message" rows="3" placeholder="Escribe tu mensaje aquí..."></textarea>
<button id="send">Enviar</button>
</div>
</div>
<script>
const chatbox = document.getElementById('chatbox');
const messageInput = document.getElementById('message');
const sendButton = document.getElementById('send');
function appendMessage(text, sender) {{
const messageDiv = document.createElement('div');
messageDiv.classList.add('message', sender);
messageDiv.textContent = text;
chatbox.appendChild(messageDiv);
chatbox.scrollTop = chatbox.scrollHeight;
}}
async function sendMessage() {{
const message = messageInput.value;
if (!message.trim()) return;
appendMessage(message, 'user');
messageInput.value = '';
const response = await fetch('/process', {{
method: 'POST',
headers: {{
'Content-Type': 'application/json'
}},
body: JSON.stringify({{
message: message,
user_id: '{user_id}'
}})
}});
const data = await response.json();
appendMessage(data.answer, 'bot');
}}
sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keypress', (e) => {{
if (e.key === 'Enter' && !e.shiftKey) {{
e.preventDefault();
sendMessage();
}}
}});
</script>
</body>
</html>
"""
return HTMLResponse(content=html_code)
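# Background training worker (runs in its own process): drains the Redis
# training queue and fine-tunes the classifier, persisting it under models/.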
def train_unified_model():
    # This worker runs in a separate process, so it loads its own tokenizer and
    # model copies instead of relying on the globals set in the web process.
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    local_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    local_tokenizer.pad_token = local_tokenizer.eos_token
    model_name = "unified_model"
    model_path = f"models/{model_name}"
    model = UnifiedModel.load_model_from_redis(redis_client)
    while True:
        training_data = redis_client.lpop("training_queue")
        if not training_data:
            time.sleep(5)  # avoid a busy loop when the queue is empty
            continue
        item_data = json.loads(training_data)
        tokenizer_data = item_data["tokenizers"]
        queue_tokenizer_name = list(tokenizer_data.keys())[0]
        if redis_client.exists(f"tokenizer:{queue_tokenizer_name}"):
            local_tokenizer.add_tokens(list(tokenizer_data[queue_tokenizer_name].keys()))
            # Resize embeddings so any newly added tokens have valid indices.
            model.resize_token_embeddings(len(local_tokenizer))
        dataset = SyntheticDataset(local_tokenizer, item_data["data"])
        training_args = TrainingArguments(
            output_dir="./results",
            per_device_train_batch_size=8,
            num_train_epochs=3,
        )
        trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
        trainer.train()
        model.save_pretrained(model_path)
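# MusicGen worker: drains "music_training_queue". The update step below is a
# rough sketch and does not implement a real MusicGen fine-tuning objective.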
def auto_learn_music():
    # Runs as a multiprocessing target, so it is a plain (non-async) loop.
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    while True:
        music_training_data = redis_client.lpop("music_training_queue")
        if not music_training_data:
            time.sleep(5)
            continue
        music_training_data = json.loads(music_training_data.decode("utf-8"))
        inputs = musicgen_tokenizer(music_training_data, return_tensors="pt", padding=True)
        musicgen_model.train()
        optimizer = torch.optim.Adam(musicgen_model.parameters(), lr=5e-5)
        loss_fn = torch.nn.CrossEntropyLoss()
        for epoch in range(1):
            outputs = musicgen_model(**inputs)
            # Placeholder objective: assumes the queued batch also supplies
            # target labels, which the tokenizer call above does not produce.
            loss = loss_fn(outputs.logits, inputs['labels'])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
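# FLUX worker: drains "image_training_queue", generates an image per prompt,
# and runs a placeholder optimisation step; this is not a real diffusion
# fine-tuning procedure.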
def auto_learn_images():
    # Runs as a multiprocessing target, so it is a plain (non-async) loop.
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    while True:
        image_training_data = redis_client.lpop("image_training_queue")
        if not image_training_data:
            time.sleep(5)
            continue
        image_training_data = json.loads(image_training_data.decode("utf-8"))
        for image_prompt in image_training_data:
            image = image_pipeline(
                image_prompt,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
                generator=torch.Generator("cpu").manual_seed(0)
            ).images[0]
            image_tensor = torch.tensor(np.array(image)).float().unsqueeze(0)
            # Placeholder update on the pipeline's denoiser (exposed by
            # FluxPipeline as .transformer); the forward call and zero target
            # below do not correspond to a real diffusion training objective.
            image_pipeline.transformer.train()
            optimizer = torch.optim.Adam(image_pipeline.transformer.parameters(), lr=1e-5)
            loss_fn = torch.nn.MSELoss()
            target_tensor = torch.zeros_like(image_tensor)
            for epoch in range(1):
                outputs = image_pipeline.transformer(image_tensor)
                loss = loss_fn(outputs, target_tensor)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
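# Entry point: start the background workers, then serve the FastAPI app.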
if __name__ == "__main__":
    training_process = multiprocessing.Process(target=train_unified_model)
    training_process.start()
    music_training_process = multiprocessing.Process(target=auto_learn_music)
    music_training_process.start()
    image_training_process = multiprocessing.Process(target=auto_learn_images)
    image_training_process.start()

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)