from dotenv import load_dotenv
import os
import json
import requests
import redis
from transformers import (
    AutoTokenizer,
    AutoModel,
    TrainingArguments,
)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from typing import List, Dict
from fastapi.responses import HTMLResponse
import multiprocessing
import time

load_dotenv()

REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = os.getenv('REDIS_PORT')
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

app = FastAPI()


class UnifiedModel(nn.Module):
    """Wraps one or more transformer backbones and classifies over their concatenated pooled states."""

    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)
        # One linear head over the concatenated hidden sizes of all backbones, two output classes.
        self.classifier = nn.Linear(sum(model.config.hidden_size for model in models), 2)

    def forward(self, input_ids, attention_mask):
        # input_ids and attention_mask are lists of tensors, one entry per backbone.
        hidden_states = []
        for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
            outputs = model(input_ids=input_id, attention_mask=attn_mask)
            # Use the first token's hidden state as the pooled representation.
            hidden_states.append(outputs.last_hidden_state[:, 0, :])
        concatenated_hidden_states = torch.cat(hidden_states, dim=-1)
        logits = self.classifier(concatenated_hidden_states)
        return logits


class SyntheticDataset(Dataset):
    """Tokenizes each text with every registered tokenizer and keys the tensors by tokenizer name."""

    def __init__(self, tokenizers, data):
        self.tokenizers = tokenizers
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        label = item['label']
        tokenized = {}
        for name, tokenizer in self.tokenizers.items():
            tokens = tokenizer(text, padding="max_length", truncation=True, max_length=128)
            tokenized[f"input_ids_{name}"] = torch.tensor(tokens["input_ids"])
            tokenized[f"attention_mask_{name}"] = torch.tensor(tokens["attention_mask"])
        tokenized["label"] = torch.tensor(label)
        return tokenized


@app.post("/process")
async def process(request: Request):
    data = await request.json()
    redis_client = redis.StrictRedis(
        host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True
    )

    tokenizers = {}
    models = {}
    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")

    if model_data_bytes:
        # The Redis value is treated as a JSON-encoded reference to a torch-saved state_dict.
        model_data = json.loads(model_data_bytes)
        model = AutoModel.from_pretrained("gpt2")
        model.load_state_dict(torch.load(model_data))
    else:
        model = AutoModel.from_pretrained("gpt2")
    models[model_name] = model

    if tokenizer_data_bytes:
        # The Redis value is treated as a JSON-encoded list of extra tokens to register.
        tokenizer_data = json.loads(tokenizer_data_bytes)
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.add_tokens(tokenizer_data)
    else:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # GPT-2 has no pad token by default; reuse EOS so padding="max_length" works in the dataset.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizers[tokenizer_name] = tokenizer

    unified_model = UnifiedModel(list(models.values()))
    unified_model.to(torch.device("cpu"))

    if data.get("train"):
        user_data = data.get("user_data", [])
        if not user_data:
            user_data = [{"text": "Sample text for automatic training.", "label": 0}]

        train_dataset = SyntheticDataset(tokenizers, user_data)
        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

        training_args = TrainingArguments(
            output_dir="memory",
            evaluation_strategy="epoch",
            learning_rate=5e-5,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            num_train_epochs=10,
            weight_decay=0.01,
            logging_steps=10,
            optim="adamw_hf",
        )

        optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
        unified_model.train()

        for epoch in range(int(training_args.num_train_epochs)):
            for batch in train_loader:
                # The unified model expects one tensor per backbone, so gather the inputs into lists.
                input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in tokenizers.keys()]
                attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in tokenizers.keys()]
                labels = batch["label"].to("cpu")

                outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
                loss = nn.CrossEntropyLoss()(outputs, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            print(f"Epoch {epoch}, Loss {loss.item()}")
        print("Training complete.")

        push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name)
        return {"message": "Model trained and updated in Redis."}

    elif data.get("predict"):
        text = data['text']
        tokenized_inputs = [tokenizers[name](text, return_tensors="pt") for name in tokenizers.keys()]
        input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
        attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]

        with torch.no_grad():
            logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            predicted_class = torch.argmax(logits, dim=-1).item()
        return {"prediction": predicted_class}

    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'predict'.")
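# A minimal client sketch for the /process endpoint. The base URL, port, and payloads are
# assumptions for illustration only; this hypothetical helper is not called anywhere in the app.
def _example_process_client(base_url: str = "http://localhost:8000"):
    # Train on user-supplied labelled text, then request a prediction for the same text.
    train_payload = {"train": True, "user_data": [{"text": "great product", "label": 1}]}
    requests.post(f"{base_url}/process", json=train_payload)

    predict_payload = {"predict": True, "text": "great product"}
    response = requests.post(f"{base_url}/process", json=predict_payload)
    return response.json()  # e.g. {"prediction": 1}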
@app.post("/external_answer")
async def external_answer(request: Request):
    data = await request.json()
    redis_client = redis.StrictRedis(
        host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True
    )

    question = data.get('question')
    if not question:
        raise HTTPException(status_code=400, detail="Question is required.")

    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"

    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")

    if model_data_bytes:
        model_data = json.loads(model_data_bytes)
        model = AutoModel.from_pretrained("gpt2")
        model.load_state_dict(torch.load(model_data))
    else:
        model = AutoModel.from_pretrained("gpt2")

    if tokenizer_data_bytes:
        tokenizer_data = json.loads(tokenizer_data_bytes)
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.add_tokens(tokenizer_data)
    else:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # GPT-2 has no pad token by default; reuse EOS so the padded training batch below works.
        tokenizer.pad_token = tokenizer.eos_token

    unified_model = UnifiedModel([model])
    unified_model.to(torch.device("cpu"))

    tokenized_input = tokenizer(question, return_tensors="pt")
    input_ids = tokenized_input['input_ids']
    attention_mask = tokenized_input['attention_mask']

    with torch.no_grad():
        # UnifiedModel.forward expects one tensor per backbone, so wrap the single input in a list.
        logits = unified_model(input_ids=[input_ids], attention_mask=[attention_mask])
        predicted_class = torch.argmax(logits, dim=-1).item()
        response = {"answer": f"Response to '{question}' is class {predicted_class}"}

    # Immediately fine-tune on the question, using the model's own prediction as the label.
    extreme_training_data = [{"text": question, "label": predicted_class}]
    train_dataset = SyntheticDataset({tokenizer_name: tokenizer}, extreme_training_data)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

    training_args = TrainingArguments(
        output_dir="memory",
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=10,
        weight_decay=0.01,
        logging_steps=10,
        optim="adamw_hf",
    )

    optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
    unified_model.train()

    for epoch in range(int(training_args.num_train_epochs)):
        for batch in train_loader:
            input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]]
            attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]]
            labels = batch["label"].to("cpu")

            outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"Epoch {epoch}, Loss {loss.item()}")
    print("Extreme training complete.")

    push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)
    return response
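# A minimal client sketch for the /external_answer endpoint. URL and question text are assumptions
# for illustration; this hypothetical helper is not wired into the application.
def _example_external_answer_client(base_url: str = "http://localhost:8000"):
    response = requests.post(f"{base_url}/external_answer", json={"question": "Is this message spam?"})
    return response.json()  # e.g. {"answer": "Response to 'Is this message spam?' is class 0"}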
@app.get("/")
async def get_home():
    html_code = """