# Hugging Face Spaces page header (scrape artifact) — Space status: Sleeping
# app.py
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image | |
from torch import nn | |
import torch.nn.functional as F | |
from datasets import load_dataset | |
from torch.utils.data import Dataset, DataLoader | |
import os | |
from tqdm import tqdm | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
class SDDataset(Dataset):
    """Training dataset: yields (processor image tensors, model one-hot, prompt multi-hot)."""

    def __init__(self, dataset, processor, model_to_idx, token_to_idx, max_samples=5000):
        """
        Args:
            dataset: HF dataset with 'image', 'model_name' and 'prompt' columns.
            processor: Florence processor turning PIL images into pixel values.
            model_to_idx: mapping checkpoint name -> class index.
            token_to_idx: mapping prompt token -> vocabulary index.
            max_samples: cap on the number of rows used.
        """
        self.dataset = dataset.select(range(min(max_samples, len(dataset))))
        self.processor = processor
        self.model_to_idx = model_to_idx
        self.token_to_idx = token_to_idx

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # HF image columns are usually decoded to PIL already; only open
        # when we actually got a path / file-like object.
        image = item['image']
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        # Some civitai uploads are RGBA/grayscale; the processor expects 3 channels.
        image = image.convert("RGB")
        image_inputs = self.processor(images=image, return_tensors="pt")
        # Drop the singleton batch dim the processor adds, so the DataLoader's
        # default collate stacks to (B, C, H, W) instead of (B, 1, C, H, W).
        image_inputs = {k: v.squeeze(0) for k, v in image_inputs.items()}
        # One-hot target over known models (float vector; cross_entropy accepts
        # class-probability targets).
        model_label = torch.zeros(len(self.model_to_idx))
        model_label[self.model_to_idx[item['model_name']]] = 1
        # Multi-hot target over the prompt vocabulary; out-of-vocab tokens are skipped.
        prompt_label = torch.zeros(len(self.token_to_idx))
        for token in item['prompt'].split():
            if token in self.token_to_idx:
                prompt_label[self.token_to_idx[token]] = 1
        return image_inputs, model_label, prompt_label
class SDRecommenderModel(nn.Module):
    """Two-headed classifier over Florence features.

    One linear head scores checkpoint (model) classes; the other scores prompt
    vocabulary tokens as independent (multi-label) logits.
    """

    def __init__(self, florence_model, num_models, vocab_size, hidden_size=1024):
        """
        Args:
            florence_model: backbone exposing `hidden_states` for pixel values.
            num_models: number of checkpoint classes.
            vocab_size: size of the prompt token vocabulary.
            hidden_size: backbone feature width (1024 for Florence-2-large).
        """
        super().__init__()
        self.florence = florence_model
        self.model_head = nn.Linear(hidden_size, num_models)
        self.prompt_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, pixel_values):
        """Return (model_logits, prompt_logits) for a batch of images.

        Callers pass either the raw pixel tensor or the processor's dict
        output, so unwrap the dict here instead of forwarding it verbatim
        to the backbone (which would fail).
        """
        if isinstance(pixel_values, dict):
            pixel_values = pixel_values["pixel_values"]
        outputs = self.florence(pixel_values=pixel_values, output_hidden_states=True)
        # Mean-pool the last hidden state over the sequence dim -> (B, hidden_size).
        features = outputs.hidden_states[-1].mean(dim=1)
        model_logits = self.model_head(features)
        prompt_logits = self.prompt_head(features)
        return model_logits, prompt_logits
class SDRecommender:
    """Recommends Stable Diffusion checkpoints and prompt keywords for an image.

    Wraps a Florence-2 backbone with two linear heads (SDRecommenderModel),
    builds label vocabularies from the civitai dataset, and exposes
    `train()` and `get_recommendations()`.
    """

    def __init__(self, max_samples=500):
        """Load backbone, processor, dataset slice, and saved weights if present.

        Args:
            max_samples: number of dataset rows used for vocab building/training.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Load Florence model and processor
        print("Loading Florence model and processor...")
        self.processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-large",
            trust_remote_code=True
        )
        self.florence = AutoModelForCausalLM.from_pretrained(
            "microsoft/Florence-2-large",
            # fp16 only on GPU; many CPU ops don't support half precision.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True
        ).to(self.device)
        # Load dataset
        print("Loading dataset...")
        self.dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train")
        self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
        print(f"Using {len(self.dataset)} samples from dataset")
        # Label vocabularies plus O(1) reverse lookups used at inference time
        # (avoids rebuilding list(dict.keys()) per predicted index).
        self.model_to_idx = self._create_model_vocab()
        self.token_to_idx = self._create_prompt_vocab()
        self.idx_to_model = [m for m, _ in sorted(self.model_to_idx.items(), key=lambda kv: kv[1])]
        self.idx_to_token = [t for t, _ in sorted(self.token_to_idx.items(), key=lambda kv: kv[1])]
        # Initialize the recommendation model
        self.model = SDRecommenderModel(
            self.florence,
            len(self.model_to_idx),
            len(self.token_to_idx)
        ).to(self.device)
        # Load trained weights if available
        if os.path.exists("recommender_model.pth"):
            self.model.load_state_dict(torch.load("recommender_model.pth", map_location=self.device))
            print("Loaded trained model weights")
        self.model.eval()

    def _create_model_vocab(self):
        """Map each distinct checkpoint name to a dense index (sorted order)."""
        print("Creating model vocabulary...")
        models = set()
        for item in self.dataset:
            models.add(item["model_name"])
        return {model: idx for idx, model in enumerate(sorted(models))}

    def _create_prompt_vocab(self):
        """Map each whitespace-separated prompt token to a dense index (sorted order)."""
        print("Creating prompt vocabulary...")
        tokens = set()
        for item in self.dataset:
            for token in item["prompt"].split():
                tokens.add(token)
        return {token: idx for idx, token in enumerate(sorted(tokens))}

    def train(self, num_epochs=5, batch_size=8, learning_rate=1e-4):
        """Fine-tune the model on the cached dataset slice and save the weights."""
        print("Starting training...")
        # Create dataset and dataloader
        train_dataset = SDDataset(
            self.dataset,
            self.processor,
            self.model_to_idx,
            self.token_to_idx
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2
        )
        # Setup optimizer
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        # Training loop
        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            for batch_idx, (images, model_labels, prompt_labels) in enumerate(progress_bar):
                # Move everything to device
                images = {k: v.to(self.device) for k, v in images.items()}
                model_labels = model_labels.to(self.device)
                prompt_labels = prompt_labels.to(self.device)
                # Forward pass — hand the pixel tensor itself to the model,
                # not the processor's whole dict.
                model_logits, prompt_logits = self.model(images["pixel_values"])
                # Single-label head: one-hot floats act as class-probability targets.
                model_loss = F.cross_entropy(model_logits, model_labels)
                # Multi-label head: independent sigmoid per vocabulary token.
                prompt_loss = F.binary_cross_entropy_with_logits(prompt_logits, prompt_labels)
                loss = model_loss + prompt_loss
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Update progress
                total_loss += loss.item()
                progress_bar.set_postfix({"loss": total_loss / (batch_idx + 1)})
        # Save trained model and return to inference mode.
        torch.save(self.model.state_dict(), "recommender_model.pth")
        print("Training completed and model saved")
        self.model.eval()

    def get_recommendations(self, image):
        """Return (models_text, prompts_text) recommendations for one image."""
        # Convert uploaded image to PIL if needed
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")  # processor expects 3-channel input
        # Process image
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Get model predictions
        self.model.eval()
        with torch.no_grad():
            model_logits, prompt_logits = self.model(inputs["pixel_values"])
        # Top checkpoint suggestions; k clamped so topk never exceeds the class count.
        model_probs = F.softmax(model_logits, dim=-1)
        top_models = torch.topk(model_probs, k=min(5, model_probs.shape[-1]))
        model_recommendations = [
            (self.idx_to_model[idx.item()], prob.item())
            for prob, idx in zip(top_models.values[0], top_models.indices[0])
        ]
        # The prompt head was trained with BCE-with-logits (multi-label), so
        # a per-token sigmoid — not a softmax across the vocabulary — is the
        # correct way to turn its logits into probabilities.
        prompt_probs = torch.sigmoid(prompt_logits)
        top_tokens = torch.topk(prompt_probs, k=min(20, prompt_probs.shape[-1]))
        recommended_tokens = [
            self.idx_to_token[idx.item()]
            for idx in top_tokens.indices[0]
        ]
        # Create 5 prompt combinations; without-replacement sampling requires
        # size <= pool size, so clamp for tiny vocabularies.
        sample_size = min(8, len(recommended_tokens))
        prompt_recommendations = [
            " ".join(np.random.choice(recommended_tokens, size=sample_size, replace=False))
            for _ in range(5)
        ]
        return (
            "\n".join(f"{model} (confidence: {conf:.2f})" for model, conf in model_recommendations),
            "\n".join(prompt_recommendations)
        )
# Create Gradio interface | |
def create_interface():
    """Build the Gradio UI, training the recommender first if no weights exist."""
    recommender = SDRecommender(max_samples=5000)
    # Train once up front when no saved checkpoint is available.
    if not os.path.exists("recommender_model.pth"):
        recommender.train()

    def process_image(image):
        # Thin adapter: Gradio hands us the PIL image, we return both textboxes.
        return recommender.get_recommendations(image)

    return gr.Interface(
        fn=process_image,
        inputs=gr.Image(type="pil"),
        outputs=[
            gr.Textbox(label="Recommended Models"),
            gr.Textbox(label="Recommended Prompts"),
        ],
        title="Stable Diffusion Model & Prompt Recommender",
        description="Upload an AI-generated image to get model and prompt recommendations",
    )
# Launch the interface | |
if __name__ == "__main__":
    # Build and serve the Gradio app.
    create_interface().launch()