|
import asyncio
import io
import json
import multiprocessing
import os
import threading
import time
import traceback
from typing import Dict, List

import redis
import requests
import torch
import torch.nn as nn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModel,
    AutoTokenizer,
    TrainingArguments,
)
|
|
|
load_dotenv()

# Redis connection settings, read from the environment after .env is loaded.
# Explicit defaults target a local Redis; without them an unset REDIS_PORT
# would reach redis-py as None, which it cannot use as a port.
REDIS_HOST = os.getenv('REDIS_HOST', 'localhost')
_redis_port_setting = os.getenv('REDIS_PORT', '6379')
try:
    REDIS_PORT = int(_redis_port_setting)
except ValueError as exc:
    # Fail at startup with a clear message instead of on the first Redis call.
    raise RuntimeError(f"REDIS_PORT must be an integer, got {_redis_port_setting!r}") from exc
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

app = FastAPI()
|
|
|
class UnifiedModel(nn.Module):
    """Classify text from the concatenated pooled outputs of one or more backbones.

    Each backbone gets its own ``input_ids`` / ``attention_mask`` pair, so every
    backbone may use its own tokenizer. One vector per backbone is pooled from
    ``last_hidden_state``. The vectors are concatenated and a single linear
    layer maps them to ``num_labels`` logits.

    Args:
        models: Backbone modules. Each must expose ``config.hidden_size`` and,
            when called with ``input_ids=`` / ``attention_mask=``, return an
            object with a ``last_hidden_state`` attribute.
        num_labels: Number of output logits (default 2).
        pooling: ``"first"`` (default) takes the hidden state at position 0;
            ``"last"`` takes the last position whose attention-mask value is
            non-zero (assumes right padding).
            NOTE(review): with causal backbones such as GPT-2, position 0 only
            attends to the first token, so ``"last"`` is usually the better
            choice. Switching changes predictions, so confirm before changing
            callers.

    Raises:
        ValueError: if ``models`` is empty or ``pooling`` is unknown.
    """

    def __init__(self, models, num_labels=2, pooling="first"):
        super().__init__()
        models = list(models)  # iterated twice below; also accepts generators
        if not models:
            raise ValueError("UnifiedModel needs at least one backbone model.")
        if pooling not in ("first", "last"):
            raise ValueError(f"pooling must be 'first' or 'last', got {pooling!r}")
        self.models = nn.ModuleList(models)
        self.pooling = pooling
        self.classifier = nn.Linear(sum(model.config.hidden_size for model in models), num_labels)

    def _pool(self, hidden, attention_mask):
        """Reduce ``hidden`` of shape (batch, seq, dim) to one (batch, dim) vector per row."""
        if self.pooling == "first":
            return hidden[:, 0, :]
        # Index of the last attended token per row; clamp keeps all-padding rows at 0.
        last = (attention_mask.sum(dim=1) - 1).clamp(min=0).long()
        return hidden[torch.arange(hidden.size(0), device=hidden.device), last]

    def forward(self, input_ids, attention_mask):
        """Return logits of shape (batch, num_labels).

        Args:
            input_ids: One tensor per backbone (a list/tuple in ``self.models``
                order). A single tensor is accepted when there is one backbone.
            attention_mask: Same structure as ``input_ids``.

        Raises:
            ValueError: if the number of id/mask tensors does not match the
                number of backbones.
        """
        # A bare tensor means "one backbone". Wrap it so zip() pairs it with the
        # backbone instead of iterating over its batch dimension.
        if isinstance(input_ids, torch.Tensor):
            input_ids = [input_ids]
        if isinstance(attention_mask, torch.Tensor):
            attention_mask = [attention_mask]
        input_ids, attention_mask = list(input_ids), list(attention_mask)
        if not (len(input_ids) == len(attention_mask) == len(self.models)):
            raise ValueError(
                f"Expected {len(self.models)} input_ids/attention_mask tensors, "
                f"got {len(input_ids)} and {len(attention_mask)}."
            )

        pooled = []
        for model, ids, mask in zip(self.models, input_ids, attention_mask):
            outputs = model(input_ids=ids, attention_mask=mask)
            pooled.append(self._pool(outputs.last_hidden_state, mask))
        return self.classifier(torch.cat(pooled, dim=-1))
|
|
|
class SyntheticDataset(Dataset):
    """Map-style dataset that tokenizes each example once per tokenizer.

    Args:
        tokenizers: Mapping of name -> tokenizer. Output keys are suffixed with
            the name (``input_ids_<name>``, ``attention_mask_<name>``) so one
            batch can feed several backbones.
        data: Sequence of dicts with ``"text"`` (str) and ``"label"`` (int).
        max_length: Every sequence is padded/truncated to this many tokens.

    Note:
        A tokenizer without a pad token gets its EOS token assigned as
        ``pad_token``. This mutates the tokenizer objects passed in.
    """

    def __init__(self, tokenizers, data, max_length=128):
        self.tokenizers = tokenizers
        self.data = data
        self.max_length = max_length
        # Tokenizers without a pad token (GPT-2's, for example) make
        # padding="max_length" raise. Reuse EOS as padding; the attention mask
        # still marks those positions as ignored.
        for tokenizer in tokenizers.values():
            if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None:
                tokenizer.pad_token = tokenizer.eos_token

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict of 1-D id/mask tensors per tokenizer plus a scalar ``label`` tensor."""
        item = self.data[idx]
        text = item['text']
        label = item['label']
        tokenized = {}
        for name, tokenizer in self.tokenizers.items():
            tokens = tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length)
            tokenized[f"input_ids_{name}"] = torch.tensor(tokens["input_ids"])
            tokenized[f"attention_mask_{name}"] = torch.tensor(tokens["attention_mask"])
        tokenized["label"] = torch.tensor(label)
        return tokenized
|
|
|
# Redis key stems used by the /process helpers below.
_UNIFIED_MODEL_NAME = "unified_model"
_UNIFIED_TOKENIZER_NAME = "unified_tokenizer"
# Model work runs in worker threads. Serialise it so concurrent /process
# requests do not each load and train a separate GPT-2 copy at the same time.
_PROCESS_MODEL_LOCK = threading.Lock()


@app.post("/process")
async def process(request: Request):
    """Train the unified model on user data, or classify a text.

    Accepted JSON bodies:

    * ``{"train": true, "user_data": [{"text": str, "label": 0 | 1}, ...]}``.
      ``user_data`` is optional; a single built-in sample is used when it is
      missing or empty. Returns ``{"message": ...}`` once the model has been
      saved to Redis.
    * ``{"predict": true, "text": str}``. Returns ``{"prediction": int}``.

    Model loading, training and inference run in a worker thread so the
    event loop keeps serving other requests.

    Raises:
        HTTPException: 400 for non-JSON bodies, bodies that request neither
            training nor prediction, or invalid ``user_data`` / ``text``.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError are ValueErrors
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    if data.get("train"):
        user_data = data.get("user_data") or [{"text": "Sample text for automatic training.", "label": 0}]
        _process_validate_examples(user_data)
        return await asyncio.to_thread(_process_train, user_data)

    if data.get("predict"):
        text = data.get("text")
        if not isinstance(text, str) or not text:
            raise HTTPException(status_code=400, detail="'text' must be a non-empty string.")
        return await asyncio.to_thread(_process_predict, text)

    raise HTTPException(status_code=400, detail="Request must contain 'train' or 'predict'.")


def _process_validate_examples(examples):
    """Raise HTTP 400 unless ``examples`` is a list of ``{"text": str, "label": 0 | 1}`` dicts."""
    if not isinstance(examples, list):
        raise HTTPException(status_code=400, detail="'user_data' must be a list of examples.")
    for index, example in enumerate(examples):
        label = example.get("label") if isinstance(example, dict) else None
        valid = (
            isinstance(example, dict)
            and isinstance(example.get("text"), str)
            # bool is an int subclass but would become a bool tensor,
            # which CrossEntropyLoss rejects.
            and isinstance(label, int)
            and not isinstance(label, bool)
            and label in (0, 1)  # the classifier produces two logits
        )
        if not valid:
            raise HTTPException(
                status_code=400,
                detail=f"user_data[{index}] must be {{'text': str, 'label': 0 or 1}}.",
            )


def _process_load_model(redis_client):
    """Build the GPT-2 unified model and tokenizer map, restoring the Redis snapshot if usable.

    Reads ``tokenizer:<name>:added_tokens`` (JSON list of tokens) and
    ``model:<name>:state`` (raw ``torch.save`` bytes of the whole
    ``UnifiedModel`` state dict). An unreadable or shape-incompatible snapshot
    is reported and ignored, leaving pretrained GPT-2 plus a fresh classifier.
    """
    # NOTE(review): GPT-2 is reloaded on every request; cache it if latency matters.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    added_tokens = redis_client.get(f"tokenizer:{_UNIFIED_TOKENIZER_NAME}:added_tokens")
    if added_tokens:
        tokenizer.add_tokens(json.loads(added_tokens))
    if tokenizer.pad_token is None:
        # GPT-2 ships without a pad token; padding="max_length" raises without one.
        tokenizer.pad_token = tokenizer.eos_token

    backbone = AutoModel.from_pretrained("gpt2")
    if len(tokenizer) > backbone.get_input_embeddings().num_embeddings:
        # Restored tokens need embedding rows before the snapshot can be loaded.
        backbone.resize_token_embeddings(len(tokenizer))
    unified_model = UnifiedModel([backbone])

    state_blob = redis_client.get(f"model:{_UNIFIED_MODEL_NAME}:state")
    if state_blob:
        try:
            # weights_only=True refuses arbitrary pickled objects from shared storage.
            state = torch.load(io.BytesIO(state_blob), map_location="cpu", weights_only=True)
            expected = unified_model.state_dict()
            # Validate before loading: load_state_dict can partially copy
            # tensors before it reports a mismatch.
            if not (
                isinstance(state, dict)
                and state.keys() == expected.keys()
                and all(getattr(state[key], "shape", None) == value.shape for key, value in expected.items())
            ):
                raise ValueError("keys or tensor shapes do not match the current model")
        except Exception as exc:
            print(f"Ignoring unusable model snapshot in Redis ({exc}); using pretrained weights.")
        else:
            unified_model.load_state_dict(state)

    return unified_model, {_UNIFIED_TOKENIZER_NAME: tokenizer}


def _process_train(user_data):
    """Fine-tune for ten epochs on ``user_data``, save the whole unified model to Redis."""
    with _PROCESS_MODEL_LOCK:
        redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
        unified_model, tokenizers = _process_load_model(redis_client)

        train_loader = DataLoader(SyntheticDataset(tokenizers, user_data), batch_size=8, shuffle=True)
        # Plain hyper-parameters replace the previous TrainingArguments object,
        # which only carried these values (and whose evaluation_strategy /
        # optim="adamw_hf" options are rejected by current transformers).
        optimizer = AdamW(unified_model.parameters(), lr=5e-5, weight_decay=0.01)
        loss_fn = nn.CrossEntropyLoss()

        unified_model.train()
        for epoch in range(10):
            for batch in train_loader:
                input_ids = [batch[f"input_ids_{name}"] for name in tokenizers]
                attention_mask = [batch[f"attention_mask_{name}"] for name in tokenizers]
                logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
                loss = loss_fn(logits, batch["label"])
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            print(f"Epoch {epoch}, Loss {loss.item()}")
        print("Training complete.")

        # Save the whole UnifiedModel (backbone + classifier). Saving only the
        # backbone would discard the trained classification head.
        push_to_redis(
            {_UNIFIED_MODEL_NAME: unified_model},
            tokenizers,
            redis_client,
            _UNIFIED_MODEL_NAME,
            _UNIFIED_TOKENIZER_NAME,
        )
    return {"message": "Model trained and updated in Redis."}


def _process_predict(text):
    """Return ``{"prediction": class_index}`` for ``text``."""
    with _PROCESS_MODEL_LOCK:
        redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
        unified_model, tokenizers = _process_load_model(redis_client)
        unified_model.eval()
        # truncation=True caps inputs at the tokenizer's model_max_length so
        # long texts do not overrun the backbone's position embeddings.
        encodings = [tokenizer(text, truncation=True, return_tensors="pt") for tokenizer in tokenizers.values()]
        with torch.no_grad():
            logits = unified_model(
                input_ids=[encoding["input_ids"] for encoding in encodings],
                attention_mask=[encoding["attention_mask"] for encoding in encodings],
            )
    return {"prediction": torch.argmax(logits, dim=-1).item()}
|
|
|
# Serialises /external_answer model work in worker threads so concurrent
# questions do not each load and fine-tune their own GPT-2 copy at once.
_EXTERNAL_ANSWER_LOCK = threading.Lock()


@app.post("/external_answer")
async def external_answer(request: Request):
    """Classify a question, then fine-tune on that prediction and save the model.

    Expects ``{"question": str}`` and returns
    ``{"answer": "Response to '<question>' is class <n>"}``. Loading,
    inference, ten epochs of self-training and saving to Redis run in a worker
    thread so the event loop keeps serving other requests. The response is
    still sent only after training finishes.

    Raises:
        HTTPException: 400 if the body is not JSON or ``question`` is missing,
            empty or not a string.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError are ValueErrors
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from None

    question = data.get('question') if isinstance(data, dict) else None
    if not question:
        raise HTTPException(status_code=400, detail="Question is required.")
    if not isinstance(question, str):
        raise HTTPException(status_code=400, detail="Question must be a string.")

    return await asyncio.to_thread(_external_answer_sync, question)


def _external_answer_load(redis_client, model_name, tokenizer_name):
    """Return ``(unified_model, tokenizer)``: GPT-2 plus any usable Redis snapshot.

    Reads ``tokenizer:<tokenizer_name>:added_tokens`` (JSON list of tokens) and
    ``model:<model_name>:state`` (raw ``torch.save`` bytes of the whole
    ``UnifiedModel`` state dict). An unreadable or incompatible snapshot is
    reported and ignored.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    added_tokens = redis_client.get(f"tokenizer:{tokenizer_name}:added_tokens")
    if added_tokens:
        tokenizer.add_tokens(json.loads(added_tokens))
    if tokenizer.pad_token is None:
        # GPT-2 ships without a pad token; padding="max_length" raises without one.
        tokenizer.pad_token = tokenizer.eos_token

    backbone = AutoModel.from_pretrained("gpt2")
    if len(tokenizer) > backbone.get_input_embeddings().num_embeddings:
        backbone.resize_token_embeddings(len(tokenizer))
    unified_model = UnifiedModel([backbone])

    state_blob = redis_client.get(f"model:{model_name}:state")
    if state_blob:
        try:
            # weights_only=True refuses arbitrary pickled objects from shared storage.
            state = torch.load(io.BytesIO(state_blob), map_location="cpu", weights_only=True)
            expected = unified_model.state_dict()
            if not (
                isinstance(state, dict)
                and state.keys() == expected.keys()
                and all(getattr(state[key], "shape", None) == value.shape for key, value in expected.items())
            ):
                raise ValueError("keys or tensor shapes do not match the current model")
        except Exception as exc:
            print(f"Ignoring unusable model snapshot in Redis ({exc}); using pretrained weights.")
        else:
            unified_model.load_state_dict(state)
    return unified_model, tokenizer


def _external_answer_sync(question):
    """Predict a class for ``question``, self-train on it, save, and return the answer dict."""
    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    with _EXTERNAL_ANSWER_LOCK:
        redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
        unified_model, tokenizer = _external_answer_load(redis_client, model_name, tokenizer_name)

        unified_model.eval()
        # truncation=True caps the input at the tokenizer's model_max_length.
        encoded = tokenizer(question, truncation=True, return_tensors="pt")
        with torch.no_grad():
            # UnifiedModel expects one tensor per backbone, hence the lists.
            logits = unified_model(input_ids=[encoded['input_ids']], attention_mask=[encoded['attention_mask']])
        predicted_class = torch.argmax(logits, dim=-1).item()
        response = {"answer": f"Response to '{question}' is class {predicted_class}"}

        # NOTE(review): training on the model's own prediction only reinforces
        # what it already predicts and cannot correct mistakes; kept because
        # callers rely on the model being updated per question.
        extreme_training_data = [{"text": question, "label": predicted_class}]
        train_loader = DataLoader(
            SyntheticDataset({tokenizer_name: tokenizer}, extreme_training_data), batch_size=8, shuffle=True
        )
        optimizer = AdamW(unified_model.parameters(), lr=5e-5, weight_decay=0.01)
        loss_fn = nn.CrossEntropyLoss()

        unified_model.train()
        for epoch in range(10):
            for batch in train_loader:
                logits = unified_model(
                    input_ids=[batch[f"input_ids_{tokenizer_name}"]],
                    attention_mask=[batch[f"attention_mask_{tokenizer_name}"]],
                )
                loss = loss_fn(logits, batch["label"])
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            print(f"Epoch {epoch}, Loss {loss.item()}")
        print("Extreme training complete.")

        # Save backbone and classifier together so the trained head survives.
        push_to_redis({model_name: unified_model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)
    return response
|
|
|
@app.get("/")
async def get_home():
    """Serve the single-page chat UI.

    The page POSTs ``{"question": ...}`` to ``/external_answer`` and renders
    each question/answer pair as a card. User text is inserted with
    ``textContent``, never ``innerHTML``: the server echoes the question
    inside ``answer``, so building markup from it would let a question
    containing HTML or script run in the page. Non-OK responses show the
    server's ``detail`` message instead of "undefined".
    """
    html_code = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>Chatbot</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            background-color: #f4f4f9;
            margin: 0;
            padding: 0;
        }
        .container {
            max-width: 1200px;
            margin: 0 auto;
            padding: 20px;
        }
        h1 {
            color: #333;
            text-align: center;
        }
        .grid-container {
            display: grid;
            grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
            gap: 10px;
            margin-top: 20px;
        }
        .grid-item {
            background: #fff;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            transition: transform 0.3s;
        }
        .grid-item:hover {
            transform: scale(1.05);
        }
        .question {
            font-weight: bold;
            color: #007bff;
        }
        .answer {
            margin-top: 10px;
            color: #333;
        }
        input[type="text"] {
            width: calc(100% - 22px);
            padding: 10px;
            margin: 0;
            border: 1px solid #ddd;
            border-radius: 4px;
        }
        button {
            padding: 10px 20px;
            background-color: #007bff;
            color: #fff;
            border: none;
            border-radius: 4px;
            cursor: pointer;
            margin-top: 10px;
        }
        button:hover {
            background-color: #0056b3;
        }
        button:disabled {
            opacity: 0.6;
            cursor: wait;
        }
    </style>
    <script>
        async function sendMessage() {
            const input = document.getElementById('question');
            const button = document.getElementById('send');
            const responseElement = document.getElementById('response');
            const question = input.value.trim();
            if (!question) {
                responseElement.textContent = "Please enter a question.";
                return;
            }

            button.disabled = true;
            let answer;
            try {
                const response = await fetch('/external_answer', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ question: question })
                });
                const data = await response.json();
                if (!response.ok) {
                    throw new Error(typeof data.detail === 'string' ? data.detail : response.statusText);
                }
                answer = data.answer;
            } catch (err) {
                responseElement.textContent = "Error: " + err.message;
                return;
            } finally {
                button.disabled = false;
            }

            responseElement.textContent = "Response: " + answer;

            // Build the card with textContent so user text is never parsed as HTML.
            const questionDiv = document.createElement('div');
            questionDiv.className = 'question';
            questionDiv.textContent = question;
            const answerDiv = document.createElement('div');
            answerDiv.className = 'answer';
            answerDiv.textContent = answer;
            const newItem = document.createElement('div');
            newItem.classList.add('grid-item');
            newItem.append(questionDiv, answerDiv);
            document.getElementById('grid-container').prepend(newItem);
        }
    </script>
</head>
<body>
    <div class="container">
        <h1>Chatbot</h1>
        <input type="text" id="question" placeholder="Ask me something..."
               onkeydown="if (event.key === 'Enter') sendMessage();">
        <button id="send" onclick="sendMessage()">Send</button>
        <div id="response"></div>
        <div class="grid-container" id="grid-container"></div>
    </div>
</body>
</html>
"""
    return HTMLResponse(content=html_code)
|
|
|
def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
    """Persist a model snapshot and its tokenizer's added tokens to Redis.

    Only the first entry of ``models`` and of ``tokenizers`` is stored. Both
    keys are written in one MULTI/EXEC transaction, so a reader never pairs a
    model from one save with tokens from another:

    * ``model:{model_name}:state``: raw ``torch.save`` bytes of the model's
      ``state_dict()``. Tensors are not JSON-serialisable, and raw bytes are
      used instead of base64 because base64 inflates the payload by a third.
      Redis caps a single string value (``proto-max-bulk-len``) at 512 MB by
      default, and a GPT-2 state dict is hundreds of MB; confirm against your
      deployment. Read it back with a client created with
      ``decode_responses=False``.
    * ``tokenizer:{tokenizer_name}:added_tokens``: JSON list of the tokens
      added on top of the base vocabulary, ordered by token id, ready for
      ``tokenizer.add_tokens``.

    These keys are distinct from the plain ``model:{name}`` /
    ``tokenizer:{name}`` keys, whose JSON encoding cannot represent tensors.

    Args:
        models: Mapping of name -> ``torch.nn.Module``.
        tokenizers: Mapping of name -> tokenizer exposing ``get_added_vocab()``.
        redis_client: redis-py client (any ``decode_responses`` setting).
        model_name: Key stem for the model snapshot.
        tokenizer_name: Key stem for the tokenizer tokens.

    Raises:
        ValueError: if ``models`` or ``tokenizers`` is empty.
    """
    if not models or not tokenizers:
        raise ValueError("push_to_redis needs at least one model and one tokenizer.")
    model = next(iter(models.values()))
    tokenizer = next(iter(tokenizers.values()))

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    added_vocab = tokenizer.get_added_vocab()
    # Sort by id so re-adding the tokens in order reproduces the same ids.
    added_tokens = sorted(added_vocab, key=added_vocab.get)

    pipe = redis_client.pipeline(transaction=True)
    pipe.set(f"model:{model_name}:state", buffer.getvalue())
    pipe.set(f"tokenizer:{tokenizer_name}:added_tokens", json.dumps(added_tokens))
    pipe.execute()
|
|
|
def continuous_training(interval_seconds=600, retry_seconds=60, train_data=None):
    """Fine-tune the Redis-persisted unified model forever, one cycle per interval.

    Each cycle reads ``model:unified_model:state`` (raw ``torch.save`` bytes of
    a ``UnifiedModel`` state dict) and ``tokenizer:unified_tokenizer:added_tokens``
    (JSON token list). It trains for ten epochs on ``train_data`` and writes the
    result back with ``push_to_redis``. Cycles are skipped while either key is
    missing. Any exception is printed with its traceback, and the loop retries
    after ``retry_seconds``.

    Args:
        interval_seconds: Pause after each completed or skipped cycle.
        retry_seconds: Pause after a failed cycle.
        train_data: Examples as ``{"text": str, "label": int}`` dicts. Defaults
            to a single built-in sample.

    Raises:
        ValueError: if ``train_data`` is an empty sequence.
    """
    if train_data is None:
        # NOTE(review): one fixed example labelled 0, trained on every cycle,
        # likely biases the shared model toward predicting class 0.
        train_data = [{"text": "Sample training text.", "label": 0}]
    if not train_data:
        raise ValueError("train_data must contain at least one example.")

    # Default decode_responses=False: the model snapshot is binary.
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    while True:
        try:
            if not _continuous_training_cycle(redis_client, model_name, tokenizer_name, train_data):
                print("No model or tokenizer found in Redis. Skipping training.")
            time.sleep(interval_seconds)
        except Exception as e:
            # Top-level boundary of a long-running worker: report and retry.
            print(f"An error occurred: {e}")
            traceback.print_exc()
            time.sleep(retry_seconds)


def _continuous_training_cycle(redis_client, model_name, tokenizer_name, train_data):
    """Run one fine-tuning pass on the Redis snapshot; return False if no snapshot exists."""
    state_blob = redis_client.get(f"model:{model_name}:state")
    tokens_blob = redis_client.get(f"tokenizer:{tokenizer_name}:added_tokens")
    if not (state_blob and tokens_blob):
        return False

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.add_tokens(json.loads(tokens_blob))
    if tokenizer.pad_token is None:
        # GPT-2 ships without a pad token; padding="max_length" raises without one.
        tokenizer.pad_token = tokenizer.eos_token

    backbone = AutoModel.from_pretrained("gpt2")
    if len(tokenizer) > backbone.get_input_embeddings().num_embeddings:
        backbone.resize_token_embeddings(len(tokenizer))
    unified_model = UnifiedModel([backbone])
    # weights_only=True refuses arbitrary pickled objects from shared storage.
    # An incompatible snapshot raises here, before anything is trained or saved.
    unified_model.load_state_dict(torch.load(io.BytesIO(state_blob), map_location="cpu", weights_only=True))

    train_loader = DataLoader(SyntheticDataset({tokenizer_name: tokenizer}, train_data), batch_size=8, shuffle=True)
    optimizer = AdamW(unified_model.parameters(), lr=5e-5, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()

    unified_model.train()
    for epoch in range(10):
        for batch in train_loader:
            logits = unified_model(
                input_ids=[batch[f"input_ids_{tokenizer_name}"]],
                attention_mask=[batch[f"attention_mask_{tokenizer_name}"]],
            )
            loss = loss_fn(logits, batch["label"])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"Epoch {epoch}, Loss {loss.item()}")

    print("Training complete.")
    push_to_redis({model_name: unified_model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)
    return True
|
|
|
def start_server(host="0.0.0.0", port=7860):
    """Serve ``app`` with uvicorn in the current process; blocks until shutdown.

    Args:
        host: Interface to bind (default: all IPv4 interfaces).
        port: TCP port to listen on (default: 7860).
    """
    # Imported here so importing this module does not require uvicorn.
    import uvicorn

    # Single worker: uvicorn only runs multiple workers when given an import
    # string rather than an app object, so a computed worker count would have
    # no effect here.
    uvicorn.run(app, host=host, port=port, timeout_keep_alive=0)
|
|
|
if __name__ == "__main__":
    # Run the HTTP API and the periodic retraining loop side by side, each in
    # its own process: start the API first, then the trainer, and wait on
    # both in the same order.
    workers = [
        multiprocessing.Process(target=start_server),
        multiprocessing.Process(target=continuous_training),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()