# CNN_MLP / app.py
# Committed by bgaspra — commit 67797ef (verified): "Update app.py" — 3.48 kB
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from transformers import BertTokenizer, BertModel
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
# Pull the first 10,000 rows of the CivitAI Stable Diffusion dataset and keep
# only the rows that actually name the generating model.
def _has_model_name(example):
    """Return True when the row's 'Model' field is present and not blank."""
    name = example['Model']
    if name is None:
        return False
    return name.strip() != ''


dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
dataset = dataset.filter(_has_model_name)

# BERT tokenizer shared by CustomDataset and predict() for prompt text.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
class CustomDataset(Dataset):
    """PyTorch ``Dataset`` yielding ``(image, text, label)`` samples.

    * ``image``: the row's ``'image'`` converted to RGB (when it supports
      ``convert``), resized to 224x224 and turned into a tensor by ``ToTensor``.
    * ``text``: the row's ``'prompt'`` tokenized by the module-level BERT
      tokenizer, padded/truncated to the tokenizer's max length. It is returned
      as a dict of tensors WITHOUT a batch axis, so ``DataLoader`` collation
      yields ``(batch, seq_len)`` tensors that BERT accepts.
    * ``label``: the integer code of the row's ``'Model'`` value, from a
      ``LabelEncoder`` fitted on the whole wrapped dataset.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Build the row once; each `self.dataset[idx]` access materializes the
        # whole row (presumably decoding the image as well).
        row = self.dataset[idx]

        image = row['image']
        # Force three channels so grayscale/RGBA images can be stacked with
        # RGB ones. NOTE(review): assumes the 'image' column decodes to PIL
        # images; anything without `.convert` is passed through unchanged.
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        image = self.transform(image)

        # Only 'Model' is filtered for None upstream; guard the prompt too.
        encoded = tokenizer(
            row['prompt'] or '',
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch axis of size 1; drop it so
        # collation gives (batch, seq_len) instead of (batch, 1, seq_len).
        text = {key: value.squeeze(0) for key, value in encoded.items()}

        label = self.labels[idx]
        return image, text, label
# Define CNN for image processing
class ImageModel(nn.Module):
    """ResNet-18 backbone whose final ``fc`` layer is replaced by a 512-unit projection."""

    def __init__(self):
        super().__init__()
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=` in favour of
            # `weights=`; DEFAULT is the same ImageNet checkpoint for ResNet-18.
            self.model = models.resnet18(weights=weights_enum.DEFAULT)
        else:
            # Older torchvision only understands the boolean flag.
            self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        """Map an image batch to ``(batch, 512)`` features.

        ``x`` is assumed to be a float tensor shaped ``(batch, 3, H, W)``
        (224x224 in this file's transforms) — confirm for other callers.
        """
        return self.model(x)
# Define text encoder: BERT followed by a single linear projection
class TextModel(nn.Module):
    """Encode tokenized text with BERT and project the pooled output to 512 features."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Take the encoder width from the checkpoint config instead of
        # hard-coding 768, so the projection always matches the loaded model.
        self.fc = nn.Linear(self.bert.config.hidden_size, 512)

    def forward(self, x):
        """Return ``(batch, 512)`` features for tokenizer output ``x``.

        ``x`` is a mapping of tokenizer outputs (e.g. ``input_ids``,
        ``attention_mask``) unpacked straight into ``BertModel``.
        """
        output = self.bert(**x)
        return self.fc(output.pooler_output)
# Combined model
class CombinedModel(nn.Module):
    """Fuse ResNet-18 image features and BERT text features into per-class scores.

    Each branch yields 512 features; they are concatenated (1024 features) and
    mapped by a linear layer to one score per class.

    Parameters
    ----------
    num_classes : int, optional
        Width of the output layer. Defaults to the number of DISTINCT values in
        the module-level ``dataset['Model']`` column, i.e. the label space that
        a ``LabelEncoder`` fitted on that column produces.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        if num_classes is None:
            # Count distinct model names, not rows: the head predicts WHICH
            # model, so its width must match the label space.
            num_classes = len(set(dataset['Model']))
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        """Return ``(batch, num_classes)`` scores.

        ``image`` is the batch fed to ``ImageModel`` (assumed
        ``(batch, 3, 224, 224)``); ``text`` is the tokenizer mapping fed to
        ``TextModel``. Both must have the same batch size for the concat.
        """
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)


# Instantiate the recommender once at import time; predict() reuses it.
model = CombinedModel()
# Define predict function
def predict(image):
    """Recommend up to five model names for an uploaded image.

    Parameters
    ----------
    image : PIL.Image.Image or None
        The upload from ``gr.Image(type="pil")``; ``None`` when the user
        submits without an image.

    Returns
    -------
    list[str]
        Model names drawn from ``dataset['Model']``, highest score first.
        Empty when no image was supplied.

    NOTE(review): the text branch always receives the fixed prompt
    "Sample prompt", and no code visible here loads trained weights into
    ``model``, so the ranking reflects untrained classifier weights.
    """
    if image is None:
        return []

    model.eval()
    # Same preprocessing as CustomDataset: 3-channel PIL image resized to
    # 224x224 before tensor conversion, then a leading batch axis.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    pixels = preprocess(image.convert('RGB')).unsqueeze(0)

    with torch.no_grad():
        text_input = tokenizer(
            "Sample prompt",
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        output = model(pixels, text_input)

    # Read the label column once rather than once per returned index.
    row_names = list(dataset['Model'])
    # A LabelEncoder numbers classes in sorted unique order, so when the head
    # has one unit per class, output column i is class_names[i]. A head with
    # one unit per dataset row falls back to row order.
    class_names = sorted(set(row_names))
    num_outputs = output.shape[1]
    lookup = class_names if num_outputs == len(class_names) else row_names

    # topk raises if k exceeds the number of outputs.
    k = min(5, num_outputs)
    _, indices = torch.topk(output, k)
    return [lookup[i] for i in indices[0].tolist()]
# Set up Gradio interface
def _format_recommendations(image):
    """Run predict() and render its list of names one per line.

    A ``gr.Textbox`` displays ``str(value)``, so returning the list directly
    would show its Python repr (brackets and quotes).
    """
    return "\n".join(predict(image))


interface = gr.Interface(
    fn=_format_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Recommended Models", lines=5),
    title="AI Image Model Recommender",
    description="Upload an AI-generated image to receive model recommendations."
)
# Launch the app
interface.launch()