import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from transformers import BertTokenizer, BertModel
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder

# Load a 10k-example slice of the Civitai dataset and drop rows without a model name
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')

# Fit a label encoder on the model names so class indices can be mapped back to model strings
label_encoder = LabelEncoder()
label_encoder.fit(dataset['Model'])
num_classes = len(label_encoder.classes_)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


# Dataset wrapper: resizes/tensorizes images, tokenizes prompts, and encodes model-name labels
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Reuse the module-level encoder so labels stay aligned with the classifier head
        self.labels = label_encoder.transform(dataset['Model'])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Force RGB so grayscale or RGBA images do not break the 3-channel backbone
        image = self.transform(self.dataset[idx]['image'].convert('RGB'))
        text = tokenizer(
            self.dataset[idx]['prompt'],
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        label = self.labels[idx]
        return image, text, label


# Image branch: ResNet-18 backbone projected to a 512-dim embedding
class ImageModel(nn.Module):
    def __init__(self):
        super(ImageModel, self).__init__()
        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        return self.model(x)


# Text branch: BERT encoder projected to a 512-dim embedding
class TextModel(nn.Module):
    def __init__(self):
        super(TextModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, 512)

    def forward(self, x):
        output = self.bert(**x)
        return self.fc(output.pooler_output)


# Fusion head: concatenates the two 512-dim embeddings and classifies over model names
class CombinedModel(nn.Module):
    def __init__(self):
        super(CombinedModel, self).__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        # One output per unique model name (len(dataset['Model']) would count rows, not classes)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)


model = CombinedModel()  # weights are untrained unless a checkpoint is loaded here
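

# The script defines CustomDataset and imports DataLoader without using them; the function
# below is a minimal, illustrative training sketch (not called by default). The hyperparameters
# are assumptions, not values from the original script.
def train(num_epochs=1, batch_size=16, lr=1e-4):
    loader = DataLoader(CustomDataset(dataset), batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for _ in range(num_epochs):
        for images, texts, labels in loader:
            # Tokenizer outputs collate to shape (batch, 1, seq_len); drop the extra dim for BERT
            texts = {k: v.squeeze(1) for k, v in texts.items()}
            optimizer.zero_grad()
            loss = criterion(model(images, texts), labels)
            loss.backward()
            optimizer.step()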


def predict(image):
    model.eval()
    with torch.no_grad():
        # Match the training-time preprocessing: RGB, 224x224, tensor, batch dimension
        image = image.convert('RGB')
        image = transforms.Resize((224, 224))(image)
        image = transforms.ToTensor()(image).unsqueeze(0)
        # Only an image is supplied at inference time, so a placeholder prompt is tokenized
        text_input = tokenizer(
            "Sample prompt",
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        output = model(image, text_input)
        _, indices = torch.topk(output, 5)
        # Map the top-5 class indices back to model names
        recommended_models = [label_encoder.classes_[i] for i in indices[0].tolist()]
    return "\n".join(recommended_models)


interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Recommended Models"),
    title="AI Image Model Recommender",
    description="Upload an AI-generated image to receive model recommendations."
)

interface.launch()