import gradio as gr
import torch
import clip
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from torch.utils.data import Dataset
import torch.nn as nn
import urllib.parse
import re

# Set device
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("Using CPU device")

# Dataset class
class ArtDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.annotations = pd.read_csv(csv_file, delimiter=";")
        self.transform = transform
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = self.annotations.iloc[idx]['filename']
        safe_img_path = urllib.parse.quote(img_path, safe="/:")
        try:
            image = Image.open(safe_img_path).convert("RGB")
            style_label = self.label_map_style[self.annotations.iloc[idx]['genre']]
            artist_label = self.label_map_artist[self.annotations.iloc[idx]['artist']]
            if self.transform:
                image = self.transform(image)
            return image, (style_label, artist_label)
        except (FileNotFoundError, OSError):
            # Unreadable or missing files yield None entries; callers must filter these out
            return None, (None, None)

# Image transformations
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load dataset
csv_file = "classes.csv"
dataset = ArtDataset(csv_file=csv_file, transform=data_transforms)

# Define model: shared ResNet-18 backbone with separate heads for style and artist
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists):
        super(DualOutputResNet, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output

# Load pre-trained model
num_styles = len(dataset.label_map_style)
num_artists = len(dataset.label_map_artist)
model = DualOutputResNet(num_styles, num_artists).to(device)
model.load_state_dict(torch.load("dual_output_resnet.pth", map_location=device))
model.eval()

# Load CLIP model
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.eval()

# Load GPT-Neo model
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Load artist and style description tables used to enrich the prompt
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()

style_desc = pd.read_csv("style_desc.csv", delimiter=';')
style_desc.columns = style_desc.columns.str.lower()

# Function to enrich prompt
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'].str.lower() == artist.lower(), 'description'].values
    style_info = style_desc.loc[style_desc['style'].str.lower() == style.lower(), 'description'].values

    # Fall back to keyword-based partial matching if no exact style match is found
    if len(style_info) == 0:
        style_keywords = style.lower().split()
        for keyword in style_keywords:
            safe_keyword = re.escape(keyword)
            partial_matches = style_desc[style_desc['style'].str.lower().str.contains(safe_keyword, na=False, regex=True)]
            if not partial_matches.empty:
                style_info = partial_matches['description'].values
                break

    artist_details = artist_info[0] if len(artist_info) > 0 else ""
    style_details = style_info[0] if len(style_info) > 0 else ""

    return f"{artist_details} This work exemplifies {style_details}."

# Function to generate description
def generate_description(image_path):
    image = Image.open(image_path).convert("RGB")
    image_resnet = data_transforms(image).unsqueeze(0).to(device)

    # Predict style and artist
    with torch.no_grad():
        outputs_style, outputs_artist = model(image_resnet)
        _, predicted_style_idx = torch.max(outputs_style, 1)
        _, predicted_artist_idx = torch.max(outputs_artist, 1)

    idx_to_style = {v: k for k, v in dataset.label_map_style.items()}
    idx_to_artist = {v: k for k, v in dataset.label_map_artist.items()}

    predicted_style = idx_to_style[predicted_style_idx.item()]
    predicted_artist = idx_to_artist[predicted_artist_idx.item()]

    # Enrich prompt
    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)

    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )

    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
    output = model_gptneo.generate(
        input_ids=input_ids,
        max_length=250,
        num_return_sequences=1,
        do_sample=True,  # required for temperature/top_p to take effect
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id  # GPT-Neo has no dedicated pad token
    )

    description_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return predicted_style, predicted_artist, description_text

# Gradio interface
def predict(image):
    style, artist, description = generate_description(image)
    return f"**Predicted Style**: {style}\n\n**Predicted Artist**: {artist}\n\n**Description**:\n{description}"

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="AI-Powered Artwork Recognition and Description",
    description="Upload an image of artwork to predict its style and artist, and generate a description."
)

if __name__ == "__main__":
    iface.launch()
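
# Optional sanity check (sketch only; the image path below is hypothetical and not part of
# the original script): the pipeline can be exercised without the Gradio UI, e.g. from a
# Python shell after importing this module:
#
#     style, artist, description = generate_description("samples/example_artwork.jpg")
#     print(style, artist)
#     print(description[:200])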