import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder

# Load the first 10k rows of the dataset and filter out null/empty model names
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')

# Tokenizer for prompt text
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Fit the label encoder once at module level: the number of distinct model
# names defines the classifier's output size, and inverse_transform maps
# predicted class indices back to names at inference time.
label_encoder = LabelEncoder()
label_encoder.fit(dataset['Model'])
num_classes = len(label_encoder.classes_)

class CustomDataset(Dataset):
    """Pairs each image with its tokenized prompt and encoded Model label."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.labels = label_encoder.transform(dataset['Model'])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # convert('RGB') guards against grayscale or RGBA images in the dataset
        image = self.transform(self.dataset[idx]['image'].convert('RGB'))
        text = tokenizer(
            self.dataset[idx]['prompt'],
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        label = self.labels[idx]
        return image, text, label

# Image branch: ResNet-18 backbone projecting images to 512-d features
class ImageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # weights=... replaces the deprecated pretrained=True argument
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        return self.model(x)

# Text branch: BERT encoder projecting the pooled prompt embedding to 512-d
class TextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, 512)

    def forward(self, x):
        output = self.bert(**x)
        return self.fc(output.pooler_output)

# Fusion head: concatenate both 512-d branches and classify over model names
class CombinedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        # One logit per distinct model name (not per dataset row)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)

# Instantiate model (note: the classification head starts untrained, so
# predictions are meaningful only after fine-tuning)
model = CombinedModel()

# Define predict function
def predict(image):
    model.eval()
    with torch.no_grad():
        image = transforms.Resize((224, 224))(image.convert('RGB'))
        image = transforms.ToTensor()(image).unsqueeze(0)
        # No prompt accompanies an upload, so a placeholder prompt is used
        text_input = tokenizer(
            "Sample prompt",
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        output = model(image, text_input)
        _, indices = torch.topk(output, k=min(5, num_classes))
        # Map predicted class indices back to model names
        recommended_models = label_encoder.inverse_transform(indices[0].tolist())
    return "\n".join(recommended_models)

# Set up Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Recommended Models"),
    title="AI Image Model Recommender",
    description="Upload an AI-generated image to receive model recommendations."
)

# Launch the app (this call blocks until the server is stopped)
interface.launch()
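
# ---------------------------------------------------------------------------
# Optional training sketch. Everything above serves an UNTRAINED classifier;
# the helpers below are a minimal, hedged example of how CombinedModel could
# be fine-tuned. The cross-entropy objective, Adam optimizer, batch size, and
# the names collate_batch / train_one_epoch are illustrative assumptions, not
# part of the original script. Because interface.launch() above blocks, move
# the commented invocation above it to actually run the loop before serving.
# ---------------------------------------------------------------------------

def collate_batch(batch):
    # CustomDataset pads every prompt to the tokenizer's max length, so the
    # per-sample [1, seq_len] text tensors can simply be concatenated.
    images, texts, labels = zip(*batch)
    images = torch.stack(images)
    text = {key: torch.cat([t[key] for t in texts], dim=0) for key in texts[0]}
    labels = torch.tensor(labels, dtype=torch.long)
    return images, text, labels

def train_one_epoch(model, loader, optimizer, criterion):
    # One pass over the data: forward, loss, backward, parameter update.
    model.train()
    for images, text, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(images, text), labels)
        loss.backward()
        optimizer.step()

# Example invocation (assumed hyperparameters):
# train_loader = DataLoader(CustomDataset(dataset), batch_size=8,
#                           shuffle=True, collate_fn=collate_batch)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# train_one_epoch(model, train_loader, optimizer, nn.CrossEntropyLoss())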