# Source: Hugging Face Space "Rec_Sys_Flo2", file app.py (9.03 kB)
# Uploaded by bgaspra; commit bd1b634 "Update app.py" (verified)
# app.py
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torch import nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM
class SDDataset(Dataset):
    """Torch dataset yielding processed images with model and prompt labels.

    Each item is a triple ``(image_inputs, model_label, prompt_label)``:

    * ``image_inputs``: mapping of the processor's output tensors with the
      per-sample batch axis removed, so that batching stacks samples into
      ``(batch, ...)`` rather than ``(batch, 1, ...)``.
    * ``model_label``: one-hot float vector over ``model_to_idx``.
    * ``prompt_label``: multi-hot float vector over ``token_to_idx``.
    """

    def __init__(self, dataset, processor, model_to_idx, token_to_idx, max_samples=5000):
        """Keep at most ``max_samples`` leading rows of ``dataset``.

        Args:
            dataset: object supporting ``len()``, ``select(range)`` and integer
                indexing that returns rows with ``'image'``, ``'model_name'``
                and ``'prompt'`` keys.
            processor: callable invoked as
                ``processor(images=..., return_tensors="pt")``.
            model_to_idx: mapping from model name to class index.
            token_to_idx: mapping from prompt token to vocabulary index.
            max_samples: upper bound on the number of rows kept.
        """
        self.dataset = dataset.select(range(min(max_samples, len(dataset))))
        self.processor = processor
        self.model_to_idx = model_to_idx
        self.token_to_idx = token_to_idx

    def __len__(self):
        """Return the number of rows kept."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image_inputs, model_label, prompt_label)`` for row ``idx``."""
        item = self.dataset[idx]

        # The row may hold an already-decoded PIL image or something
        # Image.open understands (path / file object); only open the latter.
        image = item['image']
        if not isinstance(image, Image.Image):
            image = Image.open(image)

        processed = self.processor(images=image, return_tensors="pt")
        # Processors called on a single image with return_tensors="pt" keep a
        # leading batch axis of size 1; drop it so the DataLoader's collation
        # produces (batch, ...) tensors.
        image_inputs = {
            key: value.squeeze(0) if torch.is_tensor(value) else value
            for key, value in processed.items()
        }

        model_label = self._encode_model(item['model_name'])
        prompt_label = self._encode_prompt(item['prompt'])
        return image_inputs, model_label, prompt_label

    def _encode_model(self, model_name):
        """Return a one-hot vector marking ``model_name`` in ``model_to_idx``.

        Raises:
            KeyError: if ``model_name`` is not in the model vocabulary.
        """
        label = torch.zeros(len(self.model_to_idx))
        label[self.model_to_idx[model_name]] = 1
        return label

    def _encode_prompt(self, prompt):
        """Return a multi-hot vector of the known whitespace tokens in ``prompt``.

        Tokens missing from ``token_to_idx`` are ignored; a ``None`` or empty
        prompt yields an all-zero vector.
        """
        label = torch.zeros(len(self.token_to_idx))
        known = [
            self.token_to_idx[token]
            for token in (prompt or "").split()
            if token in self.token_to_idx
        ]
        if known:
            label[known] = 1
        return label
class SDRecommenderModel(nn.Module):
    """Backbone plus two linear heads: model classification and prompt tokens.

    The backbone's last hidden state is mean-pooled over dimension 1 and fed
    to ``model_head`` (``num_models`` logits) and ``prompt_head``
    (``vocab_size`` logits).
    """

    def __init__(self, florence_model, num_models, vocab_size, hidden_size=1024):
        """Wrap ``florence_model`` and create the two heads.

        Args:
            florence_model: backbone called as
                ``florence_model(pixel_values=..., output_hidden_states=True)``.
            num_models: number of output logits of the model head.
            vocab_size: number of output logits of the prompt head.
            hidden_size: feature width fed to both heads. Defaults to 1024,
                the value previously hard-coded for Florence-2-large.
        """
        super().__init__()
        self.florence = florence_model
        self.model_head = nn.Linear(hidden_size, num_models)
        self.prompt_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, pixel_values):
        """Return ``(model_logits, prompt_logits)`` for a batch of images.

        Args:
            pixel_values: image tensor, or a mapping (such as the processor's
                output) holding that tensor under ``"pixel_values"``.
        """
        # Callers in this file hand over the whole processor output; unwrap it
        # so the backbone always receives the tensor itself.
        if not torch.is_tensor(pixel_values):
            pixel_values = pixel_values["pixel_values"]

        # The backbone may be loaded in half precision while the processor
        # emits float32; align the input with the backbone's parameter dtype.
        backbone_param = next(self.florence.parameters(), None)
        if (
            backbone_param is not None
            and backbone_param.is_floating_point()
            and pixel_values.is_floating_point()
        ):
            pixel_values = pixel_values.to(backbone_param.dtype)

        # NOTE(review): Florence-2 is an encoder-decoder model; confirm that it
        # can be called with pixel_values alone and that its output exposes
        # `hidden_states` (rather than encoder/decoder-specific fields).
        outputs = self.florence(pixel_values=pixel_values, output_hidden_states=True)
        features = outputs.hidden_states[-1].mean(dim=1)

        # The heads keep their own dtype (float32 by default); cast the pooled
        # features so nn.Linear does not fail on a dtype mismatch.
        features = features.to(self.model_head.weight.dtype)

        model_logits = self.model_head(features)
        prompt_logits = self.prompt_head(features)
        return model_logits, prompt_logits
class SDRecommender:
    """Recommend Stable Diffusion models and prompt tokens for an image.

    Construction loads the Florence-2 processor and model, a slice of the
    CivitAI dataset, builds the model/token vocabularies from that slice and
    wraps everything in an ``SDRecommenderModel``. Weights are restored from
    ``recommender_model.pth`` when that file exists.
    """

    # Sizes of the recommendation lists; each is clamped to what the
    # vocabularies can actually supply.
    TOP_MODELS = 5
    TOP_TOKENS = 20
    NUM_PROMPTS = 5
    TOKENS_PER_PROMPT = 8

    def __init__(self, max_samples=500):
        """Load models and data, build vocabularies and restore weights.

        Args:
            max_samples: maximum number of leading dataset rows to keep.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load Florence model and processor
        print("Loading Florence model and processor...")
        self.processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-large",
            trust_remote_code=True
        )
        self.florence = AutoModelForCausalLM.from_pretrained(
            "microsoft/Florence-2-large",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True
        ).to(self.device)

        # Load dataset
        print("Loading dataset...")
        self.dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train")
        self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
        print(f"Using {len(self.dataset)} samples from dataset")

        # Create vocabularies for models and tokens
        self.model_to_idx = self._create_model_vocab()
        self.token_to_idx = self._create_prompt_vocab()

        # Initialize the recommendation model
        self.model = SDRecommenderModel(
            self.florence,
            len(self.model_to_idx),
            len(self.token_to_idx)
        ).to(self.device)

        # Load trained weights if available
        if os.path.exists("recommender_model.pth"):
            self.model.load_state_dict(torch.load("recommender_model.pth", map_location=self.device))
            print("Loaded trained model weights")
        self.model.eval()

    def _create_model_vocab(self):
        """Map every distinct ``model_name`` in the dataset to a sorted index."""
        print("Creating model vocabulary...")
        models = {item["model_name"] for item in self.dataset}
        return {model: idx for idx, model in enumerate(sorted(models))}

    def _create_prompt_vocab(self):
        """Map every distinct whitespace token of the prompts to a sorted index.

        Rows whose prompt is ``None`` or empty contribute no tokens.
        """
        print("Creating prompt vocabulary...")
        tokens = set()
        for item in self.dataset:
            tokens.update((item["prompt"] or "").split())
        return {token: idx for idx, token in enumerate(sorted(tokens))}

    def train(self, num_epochs=5, batch_size=8, learning_rate=1e-4):
        """Train the recommender and save it to ``recommender_model.pth``.

        The loss is the sum of a cross-entropy term for the model head and a
        binary cross-entropy term for the prompt head. The optimizer receives
        ``self.model.parameters()``, i.e. every parameter registered on the
        wrapped model.

        Args:
            num_epochs: number of passes over the training data.
            batch_size: DataLoader batch size.
            learning_rate: AdamW learning rate.
        """
        print("Starting training...")

        # Create dataset and dataloader
        train_dataset = SDDataset(
            self.dataset,
            self.processor,
            self.model_to_idx,
            self.token_to_idx
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2
        )

        # Setup optimizer
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)

        # Training loop
        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            for batch_idx, (images, model_labels, prompt_labels) in enumerate(progress_bar):
                # The model's forward takes the image tensor, not the whole
                # processor mapping.
                pixel_values = images["pixel_values"].to(self.device)
                # Per-sample processor outputs carry their own batch axis of
                # size 1; collation then gives (B, 1, C, H, W). Merge the two
                # leading axes back into a plain batch.
                if pixel_values.dim() == 5:
                    pixel_values = pixel_values.flatten(0, 1)
                model_labels = model_labels.to(self.device)
                prompt_labels = prompt_labels.to(self.device)

                # Forward pass
                model_logits, prompt_logits = self.model(pixel_values)

                # Calculate loss in float32 so the logits match the float
                # label tensors even when the model runs in half precision.
                model_loss = F.cross_entropy(model_logits.float(), model_labels)
                prompt_loss = F.binary_cross_entropy_with_logits(prompt_logits.float(), prompt_labels)
                loss = model_loss + prompt_loss

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Update progress
                total_loss += loss.item()
                progress_bar.set_postfix({"loss": total_loss / (batch_idx + 1)})

        # Save trained model
        torch.save(self.model.state_dict(), "recommender_model.pth")
        print("Training completed and model saved")

        # Leave the model in inference mode once training is done.
        self.model.eval()

    def get_recommendations(self, image):
        """Return ``(model_text, prompt_text)`` recommendations for ``image``.

        Args:
            image: a PIL image, or anything ``Image.open`` accepts.

        Returns:
            Two newline-joined strings: the top models with their softmax
            confidence, and randomly assembled prompt suggestions.
        """
        # Convert uploaded image to PIL if needed
        if not isinstance(image, Image.Image):
            image = Image.open(image)

        # Process image
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device)

        # Get model predictions
        self.model.eval()
        with torch.no_grad():
            model_logits, prompt_logits = self.model(pixel_values)

        return self._format_recommendations(model_logits, prompt_logits)

    def _format_recommendations(self, model_logits, prompt_logits):
        """Turn the logits of the first batch row into recommendation text.

        All list sizes are clamped to the vocabulary sizes, so small
        vocabularies give shorter lists instead of raising.
        """
        # Index -> name tables, built once. This relies on the vocabularies
        # being inserted in index order, as the _create_*_vocab methods do.
        model_names = list(self.model_to_idx)
        token_names = list(self.token_to_idx)

        # Top model recommendations with softmax confidences
        model_probs = F.softmax(model_logits.float(), dim=-1)[0]
        top_models = torch.topk(model_probs, k=min(self.TOP_MODELS, model_probs.numel()))
        model_lines = [
            f"{model_names[idx]} (confidence: {prob:.2f})"
            for prob, idx in zip(top_models.values.tolist(), top_models.indices.tolist())
        ]

        # Top tokens; ranking raw logits selects the same tokens as ranking
        # their softmax would, since softmax is monotonic.
        token_scores = prompt_logits.float()[0]
        top_tokens = torch.topk(token_scores, k=min(self.TOP_TOKENS, token_scores.numel()))
        recommended_tokens = [token_names[idx] for idx in top_tokens.indices.tolist()]

        # Random prompt combinations drawn without replacement
        prompt_lines = []
        if recommended_tokens:
            tokens_per_prompt = min(self.TOKENS_PER_PROMPT, len(recommended_tokens))
            prompt_lines = [
                " ".join(np.random.choice(recommended_tokens, size=tokens_per_prompt, replace=False))
                for _ in range(self.NUM_PROMPTS)
            ]

        return "\n".join(model_lines), "\n".join(prompt_lines)
# Create Gradio interface
def create_interface():
    """Build the Gradio interface around a ready-to-use ``SDRecommender``.

    The recommender is created with ``max_samples=5000`` and trained first
    when ``recommender_model.pth`` does not exist.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    recommender = SDRecommender(max_samples=5000)

    # Train the model if no trained weights exist
    if not os.path.exists("recommender_model.pth"):
        recommender.train()

    def process_image(image):
        """Return the two textbox values for an uploaded image."""
        # Gradio passes None when the form is submitted without an image;
        # answer with a hint instead of failing while opening the image.
        if image is None:
            return "Please upload an image first.", ""
        model_recs, prompt_recs = recommender.get_recommendations(image)
        return model_recs, prompt_recs

    interface = gr.Interface(
        fn=process_image,
        inputs=gr.Image(type="pil"),
        outputs=[
            gr.Textbox(label="Recommended Models"),
            gr.Textbox(label="Recommended Prompts")
        ],
        title="Stable Diffusion Model & Prompt Recommender",
        description="Upload an AI-generated image to get model and prompt recommendations",
    )
    return interface
# Script entry point: build the Gradio app and start serving it.
if __name__ == "__main__":
    create_interface().launch()