import torch
import torch.nn as nn
import clip
import pandas as pd
import gradio as gr
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Device setup (Apple Silicon GPU if available, otherwise CPU)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Image preprocessing: resize to the ResNet input size, normalize with ImageNet statistics
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load CSVs used to enrich the GPT-Neo prompts
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';',
                           usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()

style_desc = pd.read_csv("style_desc.csv", delimiter=';')  # CSV containing style-specific descriptions
style_desc.columns = style_desc.columns.str.lower()


# Build a prompt from the artist and style descriptions in the CSVs
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'] == artist, 'description'].values
    style_info = style_desc.loc[style_desc['style'] == style, 'description'].values

    artist_details = artist_info[0] if len(artist_info) > 0 else "Details about the artist are not available."
    style_details = style_info[0] if len(style_info) > 0 else "Details about the style are not available."

    return f"{artist_details} This work exemplifies {style_details}."


# Custom dataset wrapper for the ResNet18 classifier
class ArtDataset:
    def __init__(self, csv_file):
        self.annotations = pd.read_csv(csv_file)
        self.train_data = self.annotations[self.annotations['subset'] == 'train']
        self.test_data = self.annotations[self.annotations['subset'] == 'test']
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def get_style_and_artist_mappings(self):
        return self.label_map_style, self.label_map_artist

    def get_train_test_split(self):
        return self.train_data, self.test_data


# ResNet18 backbone with dropout and two classification heads (style and artist)
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists, dropout_rate=0.5):
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # expose the pooled features instead of ImageNet logits
        self.dropout = nn.Dropout(dropout_rate)
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        features = self.dropout(features)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output


# Load dataset and label mappings
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
train_data, test_data = dataset.get_train_test_split()
num_styles = len(label_map_style)
num_artists = len(label_map_artist)

# Inverse mappings for decoding predicted indices back to names
idx_to_style = {idx: style for style, idx in label_map_style.items()}
idx_to_artist = {idx: artist for artist, idx in label_map_artist.items()}

# Model, optimizer, and LR scheduler (the scheduler halves the learning rate
# when the monitored loss plateaus for three epochs)
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
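
# --- Hypothetical training loop (sketch; not part of the original script) ----
# The optimizer and scheduler above are never stepped in this file, so the
# helper below sketches how they would typically be driven. The `train_loader`
# and `val_loader` arguments are assumed DataLoaders yielding
# (images, style_labels, artist_labels) batches; they are not defined here.
def train_one_epoch(train_loader, val_loader, criterion=nn.CrossEntropyLoss()):
    model_resnet.train()
    for images, style_labels, artist_labels in train_loader:
        images = images.to(device)
        style_labels = style_labels.to(device)
        artist_labels = artist_labels.to(device)

        optimizer.zero_grad()
        style_logits, artist_logits = model_resnet(images)
        # Sum the two head losses; a weighted sum would also be reasonable.
        loss = criterion(style_logits, style_labels) + criterion(artist_logits, artist_labels)
        loss.backward()
        optimizer.step()

    # A validation pass drives the ReduceLROnPlateau scheduler.
    model_resnet.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, style_labels, artist_labels in val_loader:
            images = images.to(device)
            style_labels = style_labels.to(device)
            artist_labels = artist_labels.to(device)
            style_logits, artist_logits = model_resnet(images)
            val_loss += (criterion(style_logits, style_labels)
                         + criterion(artist_logits, artist_labels)).item()
    scheduler.step(val_loss)  # halves the LR after `patience` stagnant epochs
    return val_loss
# ------------------------------------------------------------------------------
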
# Load CLIP and GPT-Neo (CLIP is loaded here but not used in the prediction path below)
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.eval()

model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)


# Predict style and artist with the ResNet, then generate a description with GPT-Neo
def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    image_tensor = data_transforms(image).unsqueeze(0).to(device)

    # Classify with the ResNet heads (eval mode disables dropout; no_grad skips autograd)
    model_resnet.eval()
    with torch.no_grad():
        style_logits, artist_logits = model_resnet(image_tensor)
    style_idx = torch.argmax(style_logits, dim=1).item()
    artist_idx = torch.argmax(artist_logits, dim=1).item()
    predicted_style = idx_to_style[style_idx]
    predicted_artist = idx_to_artist[artist_idx]

    # Enrich the prompt with the CSV descriptions
    prompt = enrich_prompt(predicted_artist, predicted_style)

    # Generate a text description using GPT-Neo
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model_gptneo.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=350,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # GPT-Neo has no dedicated pad token
    )
    description = tokenizer.decode(output[0], skip_special_tokens=True)

    return predicted_style, predicted_artist, description


# Gradio interface
def gradio_interface(image):
    predicted_style, predicted_artist, description = predict(image)
    return (f"Predicted Style: {predicted_style}\n"
            f"Predicted Artist: {predicted_artist}\n\n"
            f"Description:\n{description}")


iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="AI Artwork Analysis",
    description="Upload an image to predict its artistic style and creator, and generate a detailed description."
)

if __name__ == "__main__":
    iface.launch()