import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder

# Load a 10k-example slice of the CivitAI prompt/image dataset
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')

# Tokenizer for the text (prompt) branch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Encode the model names once so the dataset labels and the classifier head agree
label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(dataset['Model'])
num_classes = len(label_encoder.classes_)


class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.labels = labels

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.transform(self.dataset[idx]['image'])
        text = tokenizer(self.dataset[idx]['prompt'], padding='max_length',
                         truncation=True, return_tensors='pt')
        label = self.labels[idx]
        return image, text, label


# Image branch: ResNet-18 backbone projected to a 512-dim feature vector
class ImageModel(nn.Module):
    def __init__(self):
        super(ImageModel, self).__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        return self.model(x)


# Text branch: BERT encoder projected to a 512-dim feature vector
class TextModel(nn.Module):
    def __init__(self):
        super(TextModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, 512)

    def forward(self, x):
        # x is the BatchEncoding returned by the tokenizer
        output = self.bert(input_ids=x['input_ids'], attention_mask=x['attention_mask'])
        return self.fc(output.pooler_output)


# Fusion: concatenate both 512-dim features and classify over the model names
class CombinedModel(nn.Module):
    def __init__(self):
        super(CombinedModel, self).__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)


# Instantiate model
model = CombinedModel()


# Predict the top-5 most likely source models for an uploaded image
def predict(image):
    model.eval()
    with torch.no_grad():
        image = transforms.Resize((224, 224))(image)
        image = transforms.ToTensor()(image).unsqueeze(0)
        # Placeholder prompt; only the image is supplied by the user
        text_input = tokenizer("Sample prompt", return_tensors='pt',
                               padding=True, truncation=True)
        output = model(image, text_input)
        _, indices = torch.topk(output, 5)
        recommended_models = label_encoder.inverse_transform(indices[0].tolist())
        return list(recommended_models)


# Set up Gradio interface
interface = gr.Interface(fn=predict,
                         inputs=gr.Image(type="pil"),
                         outputs=gr.Textbox(label="Recommended Models"),
                         title="AI Image Model Recommender",
                         description="Upload an AI-generated image to receive model recommendations.")

# Launch the app
interface.launch()
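
# Note: the model above is never trained in this script, so predict() returns
# essentially random recommendations. Below is a minimal single-epoch training
# sketch (the names collate_batch and train_one_epoch are illustrative, not part
# of the original script); if used, call train_one_epoch() before interface.launch().

def collate_batch(batch):
    # Batch items are (image_tensor, BatchEncoding, label) from CustomDataset
    images, texts, targets = zip(*batch)
    images = torch.stack(images)
    text = {
        'input_ids': torch.cat([t['input_ids'] for t in texts], dim=0),
        'attention_mask': torch.cat([t['attention_mask'] for t in texts], dim=0),
    }
    targets = torch.tensor(targets, dtype=torch.long)
    return images, text, targets


def train_one_epoch(lr=1e-4, batch_size=8):
    loader = DataLoader(CustomDataset(dataset), batch_size=batch_size,
                        shuffle=True, collate_fn=collate_batch)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for images, text, targets in loader:
        optimizer.zero_grad()
        logits = model(images, text)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()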