# app.py
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM


class SDDataset(Dataset):
    """Wraps the Civitai dataset and yields (pixel_values, model_label, prompt_label) triples."""

    def __init__(self, dataset, processor, model_to_idx, token_to_idx, max_samples=5000):
        self.dataset = dataset.select(range(min(max_samples, len(dataset))))
        self.processor = processor
        self.model_to_idx = model_to_idx
        self.token_to_idx = token_to_idx

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # The "image" column may hold a decoded PIL image or a file path,
        # depending on the dataset features; handle both and force RGB.
        image = item["image"]
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")

        # Drop the batch dimension the processor adds so the DataLoader can
        # collate samples into a single (B, C, H, W) tensor.
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)

        # Model label as a class index, which is what F.cross_entropy expects
        model_label = torch.tensor(self.model_to_idx[item["model_name"]], dtype=torch.long)

        # Prompt label as a multi-hot vector over the token vocabulary
        prompt_label = torch.zeros(len(self.token_to_idx))
        for token in item["prompt"].split():
            if token in self.token_to_idx:
                prompt_label[self.token_to_idx[token]] = 1

        return pixel_values, model_label, prompt_label


class SDRecommenderModel(nn.Module):
    """Florence-2 backbone with two linear heads: one over checkpoint names, one over prompt tokens."""

    def __init__(self, florence_model, num_models, vocab_size):
        super().__init__()
        self.florence = florence_model
        hidden_size = 1024  # d_model of Florence-2-large's language model
        self.model_head = nn.Linear(hidden_size, num_models)
        self.prompt_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, pixel_values):
        # Match the backbone dtype: the processor returns float32 tensors
        # even when the backbone was loaded in float16.
        pixel_values = pixel_values.to(self.florence.dtype)

        # NOTE: this assumes the Florence-2 remote code accepts a vision-only
        # forward pass; some revisions also require input_ids, in which case
        # the vision encoder has to be called directly instead.
        outputs = self.florence(pixel_values=pixel_values, output_hidden_states=True)

        # Mean-pool the last hidden state into one feature vector per image,
        # cast back to float32 so it matches the float32 heads.
        features = outputs.hidden_states[-1].mean(dim=1).float()

        model_logits = self.model_head(features)
        prompt_logits = self.prompt_head(features)
        return model_logits, prompt_logits


class SDRecommender:
    def __init__(self, max_samples=500):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load Florence model and processor
        print("Loading Florence model and processor...")
        self.processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-large", trust_remote_code=True
        )
        self.florence = AutoModelForCausalLM.from_pretrained(
            "microsoft/Florence-2-large",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
        ).to(self.device)

        # Load dataset. The column names used below ("image", "model_name",
        # "prompt") assume this dataset's schema; adjust them if the revision
        # you pull uses different names.
        print("Loading dataset...")
        self.dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train")
        self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
        print(f"Using {len(self.dataset)} samples from dataset")

        # Create vocabularies for models and prompt tokens, plus reverse
        # lookups for decoding predictions later.
        self.model_to_idx = self._create_model_vocab()
        self.token_to_idx = self._create_prompt_vocab()
        self.idx_to_model = {idx: model for model, idx in self.model_to_idx.items()}
        self.idx_to_token = {idx: token for token, idx in self.token_to_idx.items()}

        # Initialize the recommendation model
        self.model = SDRecommenderModel(
            self.florence, len(self.model_to_idx), len(self.token_to_idx)
        ).to(self.device)

        # Load trained weights if available
        if os.path.exists("recommender_model.pth"):
            self.model.load_state_dict(torch.load("recommender_model.pth", map_location=self.device))
            print("Loaded trained model weights")
        self.model.eval()

    def _create_model_vocab(self):
        print("Creating model vocabulary...")
        models = {item["model_name"] for item in self.dataset}
        return {model: idx for idx, model in enumerate(sorted(models))}

    def _create_prompt_vocab(self):
        print("Creating prompt vocabulary...")
        tokens = set()
        for item in self.dataset:
            tokens.update(item["prompt"].split())
        return {token: idx for idx, token in enumerate(sorted(tokens))}

    def train(self, num_epochs=5, batch_size=8, learning_rate=1e-4):
        print("Starting training...")

        # Create dataset and dataloader
        train_dataset = SDDataset(
            self.dataset, self.processor, self.model_to_idx, self.token_to_idx
        )
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
        )

        # Set up the optimizer. Note this fine-tunes the whole backbone as
        # well as the heads; running a float16 backbone through vanilla AdamW
        # can be numerically unstable, so freezing the backbone or using
        # torch.amp mixed precision are safer alternatives.
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)

        # Training loop
        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0.0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
            for batch_idx, (pixel_values, model_labels, prompt_labels) in enumerate(progress_bar):
                # Move everything to device
                pixel_values = pixel_values.to(self.device)
                model_labels = model_labels.to(self.device)
                prompt_labels = prompt_labels.to(self.device)

                # Forward pass
                model_logits, prompt_logits = self.model(pixel_values)

                # Single-label loss for the model head, multi-label loss for
                # the prompt head
                model_loss = F.cross_entropy(model_logits, model_labels)
                prompt_loss = F.binary_cross_entropy_with_logits(prompt_logits, prompt_labels)
                loss = model_loss + prompt_loss

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Update progress
                total_loss += loss.item()
                progress_bar.set_postfix({"loss": total_loss / (batch_idx + 1)})

        # Save trained model
        torch.save(self.model.state_dict(), "recommender_model.pth")
        print("Training completed and model saved")

    def get_recommendations(self, image):
        # Convert uploaded image to PIL if needed
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")

        # Process image
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"].to(self.device)

        # Get model predictions
        self.model.eval()
        with torch.no_grad():
            model_logits, prompt_logits = self.model(pixel_values)

        # Top 5 model recommendations (softmax is appropriate here: the model
        # head was trained as a single-label classifier)
        model_probs = F.softmax(model_logits, dim=-1)
        top_models = torch.topk(model_probs, k=min(5, model_probs.shape[-1]))
        model_recommendations = [
            (self.idx_to_model[idx.item()], prob.item())
            for prob, idx in zip(top_models.values[0], top_models.indices[0])
        ]

        # Top tokens for prompt recommendations (sigmoid, not softmax: the
        # prompt head was trained as a multi-label classifier with BCE)
        prompt_probs = torch.sigmoid(prompt_logits)
        top_tokens = torch.topk(prompt_probs, k=min(20, prompt_probs.shape[-1]))
        recommended_tokens = [
            self.idx_to_token[idx.item()] for idx in top_tokens.indices[0]
        ]

        # Create 5 prompt suggestions by sampling token subsets
        sample_size = min(8, len(recommended_tokens))
        prompt_recommendations = [
            " ".join(np.random.choice(recommended_tokens, size=sample_size, replace=False))
            for _ in range(5)
        ]

        return (
            "\n".join(f"{model} (confidence: {conf:.2f})" for model, conf in model_recommendations),
            "\n".join(prompt_recommendations),
        )


# Create Gradio interface
def create_interface():
    recommender = SDRecommender(max_samples=5000)

    # Train the model if no trained weights exist
    if not os.path.exists("recommender_model.pth"):
        recommender.train()

    def process_image(image):
        return recommender.get_recommendations(image)

    interface = gr.Interface(
        fn=process_image,
        inputs=gr.Image(type="pil"),
        outputs=[
            gr.Textbox(label="Recommended Models"),
            gr.Textbox(label="Recommended Prompts"),
        ],
        title="Stable Diffusion Model & Prompt Recommender",
        description="Upload an AI-generated image to get model and prompt recommendations",
    )
    return interface


# Launch the interface
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
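
# Usage sketch (assumptions: the file is saved as app.py, and Florence-2's
# remote code may pull in extra dependencies such as einops/timm beyond the
# imports above):
#
#   pip install gradio torch numpy pillow datasets tqdm transformers
#   python app.py
#
# On first run there are no saved weights, so create_interface() trains the
# heads and writes recommender_model.pth next to the script; later runs load
# those weights and go straight to the Gradio UI. To query the recommender
# without the UI (hypothetical image path):
#
#   rec = SDRecommender(max_samples=500)
#   models, prompts = rec.get_recommendations(Image.open("example.png"))
#   print(models)
#   print(prompts)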