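"""FastAPI service that serves and continuously fine-tunes a small GPT-2-based
classifier ("UnifiedModel"), persisting its weights and vocabulary in Redis.

Endpoints: POST /process (train or predict), POST /external_answer (answer a
question and fine-tune on the result), GET / (simple chat UI). A separate
process runs continuous_training() in the background.
"""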
import base64
import io
import json
import multiprocessing
import os
import time

import redis
import torch
import torch.nn as nn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    AutoModel,
    TrainingArguments,
)

load_dotenv()

REDIS_HOST = os.getenv('REDIS_HOST')
# redis-py expects an integer port; environment variables are strings.
REDIS_PORT = int(os.getenv('REDIS_PORT', 6379))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

app = FastAPI()
class UnifiedModel(nn.Module):
    """Wraps one or more Transformer backbones behind a shared two-class linear head."""

    def __init__(self, models):
        super(UnifiedModel, self).__init__()
        self.models = nn.ModuleList(models)
        self.classifier = nn.Linear(sum(model.config.hidden_size for model in models), 2)

    def forward(self, input_ids, attention_mask):
        # input_ids and attention_mask are lists with one tensor per backbone.
        hidden_states = []
        for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
            outputs = model(
                input_ids=input_id,
                attention_mask=attn_mask
            )
            # Pool each sequence by taking the hidden state of its first token.
            hidden_states.append(outputs.last_hidden_state[:, 0, :])
        concatenated_hidden_states = torch.cat(hidden_states, dim=-1)
        logits = self.classifier(concatenated_hidden_states)
        return logits
class SyntheticDataset(Dataset):
    """Tokenizes a list of {"text", "label"} records once per registered tokenizer."""

    def __init__(self, tokenizers, data):
        self.tokenizers = tokenizers
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        label = item['label']
        tokenized = {}
        for name, tokenizer in self.tokenizers.items():
            tokens = tokenizer(text, padding="max_length", truncation=True, max_length=128)
            tokenized[f"input_ids_{name}"] = torch.tensor(tokens["input_ids"])
            tokenized[f"attention_mask_{name}"] = torch.tensor(tokens["attention_mask"])
        tokenized["label"] = torch.tensor(label)
        return tokenized
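# POST /process — a JSON body with {"train": true, "user_data": [...]} fine-tunes the
# unified model on the supplied {"text", "label"} records; {"predict": true, "text": ...}
# returns the predicted class. Updated weights are pushed back to Redis after training.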
@app.post("/process")
async def process(request: Request):
    data = await request.json()
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)

    tokenizers = {}
    models = {}
    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")

    model = AutoModel.from_pretrained("gpt2")
    if model_data_bytes:
        # State dicts are stored as base64-encoded torch.save buffers (see push_to_redis).
        state_dict = torch.load(io.BytesIO(base64.b64decode(model_data_bytes)))
        model.load_state_dict(state_dict)
    models[model_name] = model

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer_data_bytes:
        # The stored payload is the vocabulary as a JSON dict; add_tokens skips
        # tokens the base tokenizer already knows about.
        tokenizer_data = json.loads(tokenizer_data_bytes)
        tokenizer.add_tokens(list(tokenizer_data.keys()))
    # GPT-2 ships without a padding token; reuse EOS so padding="max_length" works.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizers[tokenizer_name] = tokenizer

    unified_model = UnifiedModel(list(models.values()))
    unified_model.to(torch.device("cpu"))

    if data.get("train"):
        user_data = data.get("user_data", [])
        if not user_data:
            user_data = [{"text": "Sample text for automatic training.", "label": 0}]

        train_dataset = SyntheticDataset(tokenizers, user_data)
        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

        # TrainingArguments is only used as a container for hyperparameters here;
        # the loop below is a plain PyTorch training loop, not a Trainer run.
        training_args = TrainingArguments(
            output_dir="memory",
            evaluation_strategy="epoch",
            learning_rate=5e-5,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            num_train_epochs=10,
            weight_decay=0.01,
            logging_steps=10,
            optim="adamw_hf"
        )

        optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
        unified_model.train()
        for epoch in range(int(training_args.num_train_epochs)):
            for batch in train_loader:
                input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in tokenizers.keys()]
                attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in tokenizers.keys()]
                labels = batch["label"].to("cpu")
                outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
                loss = nn.CrossEntropyLoss()(outputs, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            print(f"Epoch {epoch}, Loss {loss.item()}")

        print("Training complete.")
        push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name)
        return {"message": "Model trained and updated in Redis."}

    elif data.get("predict"):
        text = data['text']
        tokenized_inputs = [tokenizers[name](text, return_tensors="pt") for name in tokenizers.keys()]
        input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
        attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]

        with torch.no_grad():
            logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            predicted_class = torch.argmax(logits, dim=-1).item()
        return {"prediction": predicted_class}

    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'predict'.")
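# POST /external_answer — classifies the incoming question, returns the answer, and then
# immediately fine-tunes the model on its own (question, predicted class) pair before
# pushing the updated weights back to Redis.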
@app.post("/external_answer")
async def external_answer(request: Request):
    data = await request.json()
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)

    question = data.get('question')
    if not question:
        raise HTTPException(status_code=400, detail="Question is required.")

    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"
    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")

    model = AutoModel.from_pretrained("gpt2")
    if model_data_bytes:
        # State dicts are stored as base64-encoded torch.save buffers (see push_to_redis).
        model.load_state_dict(torch.load(io.BytesIO(base64.b64decode(model_data_bytes))))

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer_data_bytes:
        tokenizer_data = json.loads(tokenizer_data_bytes)
        tokenizer.add_tokens(list(tokenizer_data.keys()))
    tokenizer.pad_token = tokenizer.eos_token

    unified_model = UnifiedModel([model])
    unified_model.to(torch.device("cpu"))

    # Answer the question with the current classifier head.
    tokenized_input = tokenizer(question, return_tensors="pt")
    input_ids = tokenized_input['input_ids']
    attention_mask = tokenized_input['attention_mask']
    with torch.no_grad():
        # UnifiedModel expects one tensor per backbone, hence the single-element lists.
        logits = unified_model(input_ids=[input_ids], attention_mask=[attention_mask])
        predicted_class = torch.argmax(logits, dim=-1).item()
    response = {"answer": f"Response to '{question}' is class {predicted_class}"}

    # Immediately fine-tune on the (question, prediction) pair ("extreme" online training).
    extreme_training_data = [{"text": question, "label": predicted_class}]
    train_dataset = SyntheticDataset({tokenizer_name: tokenizer}, extreme_training_data)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

    training_args = TrainingArguments(
        output_dir="memory",
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=10,
        weight_decay=0.01,
        logging_steps=10,
        optim="adamw_hf"
    )

    optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
    unified_model.train()
    for epoch in range(int(training_args.num_train_epochs)):
        for batch in train_loader:
            input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]]
            attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]]
            labels = batch["label"].to("cpu")
            outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"Epoch {epoch}, Loss {loss.item()}")

    print("Extreme training complete.")
    push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)
    return response
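# GET / — serves a minimal single-page chat UI that posts questions to /external_answer.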
@app.get("/")
async def get_home():
    html_code = """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Chatbot</title>
<style>
body {
font-family: Arial, sans-serif;
background-color: #f4f4f9;
margin: 0;
padding: 0;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
h1 {
color: #333;
text-align: center;
}
.grid-container {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
gap: 10px;
margin-top: 20px;
}
.grid-item {
background: #fff;
padding: 20px;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
transition: transform 0.3s;
}
.grid-item:hover {
transform: scale(1.05);
}
.question {
font-weight: bold;
color: #007bff;
}
.answer {
margin-top: 10px;
color: #333;
}
input[type="text"] {
width: calc(100% - 22px);
padding: 10px;
margin: 0;
border: 1px solid #ddd;
border-radius: 4px;
}
button {
padding: 10px 20px;
background-color: #007bff;
color: #fff;
border: none;
border-radius: 4px;
cursor: pointer;
margin-top: 10px;
}
button:hover {
background-color: #0056b3;
}
</style>
<script>
async function sendMessage() {
const question = document.getElementById('question').value;
const responseElement = document.getElementById('response');
const response = await fetch('/external_answer', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ question: question })
});
const data = await response.json();
responseElement.innerText = "Response: " + data.answer;
const gridContainer = document.getElementById('grid-container');
const newItem = document.createElement('div');
newItem.classList.add('grid-item');
newItem.innerHTML = `<div class="question">${question}</div><div class="answer">${data.answer}</div>`;
gridContainer.prepend(newItem);
}
</script>
</head>
<body>
<div class="container">
<h1>Chatbot</h1>
<input type="text" id="question" placeholder="Ask me something...">
<button onclick="sendMessage()">Send</button>
<div id="response"></div>
<div class="grid-container" id="grid-container"></div>
</div>
</body>
</html>
"""
return HTMLResponse(content=html_code)
def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
    # Tensors are not JSON-serializable; store the state dict as a base64-encoded
    # torch.save buffer (the Redis client uses decode_responses=True, strings only).
    buffer = io.BytesIO()
    torch.save(next(iter(models.values())).state_dict(), buffer)
    redis_client.set(f"model:{model_name}", base64.b64encode(buffer.getvalue()).decode("ascii"))
    # The vocabulary is a plain dict and can be stored as JSON.
    tokenizer_data = json.dumps(next(iter(tokenizers.values())).get_vocab())
    redis_client.set(f"tokenizer:{tokenizer_name}", tokenizer_data)
def continuous_training():
    """Background worker: periodically reloads the model from Redis, fine-tunes it
    on placeholder data, and writes the updated weights back."""
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"
    while True:
        try:
            model_data_bytes = redis_client.get(f"model:{model_name}")
            tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
            if model_data_bytes and tokenizer_data_bytes:
                model = AutoModel.from_pretrained("gpt2")
                # State dicts are stored as base64-encoded torch.save buffers (see push_to_redis).
                model.load_state_dict(torch.load(io.BytesIO(base64.b64decode(model_data_bytes))))
                tokenizer = AutoTokenizer.from_pretrained("gpt2")
                tokenizer.add_tokens(list(json.loads(tokenizer_data_bytes).keys()))
                tokenizer.pad_token = tokenizer.eos_token

                unified_model = UnifiedModel([model])
                unified_model.to(torch.device("cpu"))

                train_data = [{"text": "Sample training text.", "label": 0}]
                train_dataset = SyntheticDataset({tokenizer_name: tokenizer}, train_data)
                train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

                training_args = TrainingArguments(
                    output_dir="memory",
                    evaluation_strategy="epoch",
                    learning_rate=5e-5,
                    per_device_train_batch_size=8,
                    per_device_eval_batch_size=8,
                    num_train_epochs=10,
                    weight_decay=0.01,
                    logging_steps=10,
                    optim="adamw_hf"
                )

                optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
                unified_model.train()
                for epoch in range(int(training_args.num_train_epochs)):
                    for batch in train_loader:
                        input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]]
                        attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]]
                        labels = batch["label"].to("cpu")
                        outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
                        loss = nn.CrossEntropyLoss()(outputs, labels)
                        loss.backward()
                        optimizer.step()
                        optimizer.zero_grad()
                    print(f"Epoch {epoch}, Loss {loss.item()}")

                print("Training complete.")
                push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)
            else:
                print("No model or tokenizer found in Redis. Skipping training.")
            time.sleep(600)
        except Exception as e:
            print(f"An error occurred: {e}")
            time.sleep(60)
def start_server():
    import uvicorn
    # uvicorn.run() with an app object always runs a single worker; to scale out,
    # pass the app as an import string (e.g. "main:app") together with workers=N.
    uvicorn.run(app, host="0.0.0.0", port=7860, timeout_keep_alive=0)


if __name__ == "__main__":
    # Run the API server and the continuous-training loop in separate processes.
    api_process = multiprocessing.Process(target=start_server)
    training_process = multiprocessing.Process(target=continuous_training)
    api_process.start()
    training_process.start()
    api_process.join()
    training_process.join()
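# Example client usage (hypothetical, assuming the server runs locally on port 7860):
#   import requests
#   requests.post("http://localhost:7860/process",
#                 json={"train": True, "user_data": [{"text": "hello", "label": 0}]})
#   requests.post("http://localhost:7860/external_answer",
#                 json={"question": "Is this working?"})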