from dotenv import load_dotenv import os import json import requests import redis from transformers import ( AutoTokenizer, AutoModel, TrainingArguments, ) import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from torch.optim import AdamW from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel from typing import List, Dict from fastapi.responses import HTMLResponse import multiprocessing import time load_dotenv() REDIS_HOST = os.getenv('REDIS_HOST') REDIS_PORT = os.getenv('REDIS_PORT') REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') app = FastAPI() class UnifiedModel(nn.Module): def __init__(self, models): super(UnifiedModel, self).__init__() self.models = nn.ModuleList(models) self.classifier = nn.Linear(sum([model.config.hidden_size for model in models]), 2) def forward(self, input_ids, attention_mask): hidden_states = [] for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask): outputs = model( input_ids=input_id, attention_mask=attn_mask ) hidden_states.append(outputs.last_hidden_state[:, 0, :]) concatenated_hidden_states = torch.cat(hidden_states, dim=-1) logits = self.classifier(concatenated_hidden_states) return logits class SyntheticDataset(Dataset): def __init__(self, tokenizers, data): self.tokenizers = tokenizers self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): item = self.data[idx] text = item['text'] label = item['label'] tokenized = {} for name, tokenizer in self.tokenizers.items(): tokens = tokenizer(text, padding="max_length", truncation=True, max_length=128) tokenized[f"input_ids_{name}"] = torch.tensor(tokens["input_ids"]) tokenized[f"attention_mask_{name}"] = torch.tensor(tokens["attention_mask"]) tokenized["label"] = torch.tensor(label) return tokenized @app.post("/process") async def process(request: Request): data = await request.json() redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True) tokenizers = {} models = {} model_name = "unified_model" tokenizer_name = "unified_tokenizer" model_data_bytes = redis_client.get(f"model:{model_name}") tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}") if model_data_bytes: model_data = json.loads(model_data_bytes) model = AutoModel.from_pretrained("gpt2") model.load_state_dict(torch.load(model_data)) else: model = AutoModel.from_pretrained("gpt2") models[model_name] = model if tokenizer_data_bytes: tokenizer_data = json.loads(tokenizer_data_bytes) tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer.add_tokens(tokenizer_data) else: tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizers[tokenizer_name] = tokenizer unified_model = UnifiedModel(list(models.values())) unified_model.to(torch.device("cpu")) if data.get("train"): user_data = data.get("user_data", []) if not user_data: user_data = [{"text": "Sample text for automatic training.", "label": 0}] train_dataset = SyntheticDataset(tokenizers, user_data) train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True) training_args = TrainingArguments( output_dir="memory", evaluation_strategy="epoch", learning_rate=5e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=10, weight_decay=0.01, logging_steps=10, optim="adamw_hf" ) optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate) unified_model.train() for epoch in range(training_args.num_train_epochs): for batch in train_loader: input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in tokenizers.keys()] attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in tokenizers.keys()] labels = batch["label"].to("cpu") outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask) loss = nn.CrossEntropyLoss()(outputs, labels) loss.backward() optimizer.step() optimizer.zero_grad() print(f"Epoch {epoch}, Loss {loss.item()}") print("Training complete.") push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name) return {"message": "Model trained and updated in Redis."} elif data.get("predict"): text = data['text'] tokenized_inputs = [tokenizers[name](text, return_tensors="pt") for name in tokenizers.keys()] input_ids = [tokens['input_ids'] for tokens in tokenized_inputs] attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs] with torch.no_grad(): logits = unified_model(input_ids=input_ids, attention_mask=attention_mask) predicted_class = torch.argmax(logits, dim=-1).item() return {"prediction": predicted_class} else: raise HTTPException(status_code=400, detail="Request must contain 'train' or 'predict'.") @app.post("/external_answer") async def external_answer(request: Request): data = await request.json() redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True) question = data.get('question') if not question: raise HTTPException(status_code=400, detail="Question is required.") model_name = "unified_model" tokenizer_name = "unified_tokenizer" model_data_bytes = redis_client.get(f"model:{model_name}") tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}") if model_data_bytes: model_data = json.loads(model_data_bytes) model = AutoModel.from_pretrained("gpt2") model.load_state_dict(torch.load(model_data)) else: model = AutoModel.from_pretrained("gpt2") if tokenizer_data_bytes: tokenizer_data = json.loads(tokenizer_data_bytes) tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer.add_tokens(tokenizer_data) else: tokenizer = AutoTokenizer.from_pretrained("gpt2") unified_model = UnifiedModel([model]) unified_model.to(torch.device("cpu")) tokenized_input = tokenizer(question, return_tensors="pt") input_ids = tokenized_input['input_ids'] attention_mask = tokenized_input['attention_mask'] with torch.no_grad(): logits = unified_model(input_ids=input_ids, attention_mask=attention_mask) predicted_class = torch.argmax(logits, dim=-1).item() response = {"answer": f"Response to '{question}' is class {predicted_class}"} extreme_training_data = [{"text": question, "label": predicted_class}] train_dataset = SyntheticDataset({tokenizer_name: tokenizer}, extreme_training_data) train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True) training_args = TrainingArguments( output_dir="memory", evaluation_strategy="epoch", learning_rate=5e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=10, weight_decay=0.01, logging_steps=10, optim="adamw_hf" ) optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate) unified_model.train() for epoch in range(training_args.num_train_epochs): for batch in train_loader: input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]] attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]] labels = batch["label"].to("cpu") outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask) loss = nn.CrossEntropyLoss()(outputs, labels) loss.backward() optimizer.step() optimizer.zero_grad() print(f"Epoch {epoch}, Loss {loss.item()}") print("Extreme training complete.") push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name) return response @app.get("/") async def get_home(): html_code = """ Chatbot

Chatbot

""" return HTMLResponse(content=html_code) def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name): model_data = json.dumps(next(iter(models.values())).state_dict()) redis_client.set(f"model:{model_name}", model_data) tokenizer_data = json.dumps(next(iter(tokenizers.values())).get_vocab()) redis_client.set(f"tokenizer:{tokenizer_name}", tokenizer_data) def continuous_training(): redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True) model_name = "unified_model" tokenizer_name = "unified_tokenizer" while True: try: model_data_bytes = redis_client.get(f"model:{model_name}") tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}") if model_data_bytes and tokenizer_data_bytes: model_data = json.loads(model_data_bytes) model = AutoModel.from_pretrained("gpt2") model.load_state_dict(torch.load(model_data)) tokenizer_data = json.loads(tokenizer_data_bytes) tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer.add_tokens(tokenizer_data) unified_model = UnifiedModel([model]) unified_model.to(torch.device("cpu")) train_data = [{"text": "Sample training text.", "label": 0}] train_dataset = SyntheticDataset({tokenizer_name: tokenizer}, train_data) train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True) training_args = TrainingArguments( output_dir="memory", evaluation_strategy="epoch", learning_rate=5e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=10, weight_decay=0.01, logging_steps=10, optim="adamw_hf" ) optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate) unified_model.train() for epoch in range(training_args.num_train_epochs): for batch in train_loader: input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]] attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]] labels = batch["label"].to("cpu") outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask) loss = nn.CrossEntropyLoss()(outputs, labels) loss.backward() optimizer.step() optimizer.zero_grad() print(f"Epoch {epoch}, Loss {loss.item()}") print("Training complete.") push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name) else: print("No model or tokenizer found in Redis. Skipping training.") time.sleep(600) except Exception as e: print(f"An error occurred: {e}") time.sleep(60) def start_server(): import uvicorn cpu_cores = os.cpu_count() or 1 num_workers = max(1, cpu_cores - 1) uvicorn.run(app, host="0.0.0.0", port=7860, timeout_keep_alive=0) if __name__ == "__main__": api_process = multiprocessing.Process(target=start_server) training_process = multiprocessing.Process(target=continuous_training) api_process.start() training_process.start() api_process.join() training_process.join()