Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import clip | |
from PIL import Image | |
from torchvision import transforms, models | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import pandas as pd | |
from torch.utils.data import Dataset | |
import torch.nn as nn | |
import urllib.parse | |
import re | |
# Select the compute device: CUDA GPU if present, else Apple-silicon MPS, else CPU.
# getattr guard: older torch builds do not expose torch.backends.mps at all.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Utilizzo del dispositivo CUDA")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Utilizzo del dispositivo MPS")
else:
    device = torch.device("cpu")
    print("Utilizzo del dispositivo CPU")
# Dataset class
class ArtDataset(Dataset):
    """Artwork images listed in a ``;``-delimited CSV, labelled by style and artist.

    The CSV must provide ``filename``, ``genre`` and ``artist`` columns. Each
    distinct genre/artist value is mapped to an integer index in the order it
    first appears in the CSV (``Series.unique`` order).

    NOTE(review): the classifier checkpoint is decoded through these maps, so
    presumably this must be the same CSV (same row order) used at training
    time — not checked here.
    """

    def __init__(self, csv_file, transform=None):
        self.annotations = pd.read_csv(csv_file, delimiter=";")
        self.transform = transform
        # label value -> dense integer index, in first-appearance order.
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations["genre"].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations["artist"].unique())}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, (style_label, artist_label))`` for row *idx*.

        A missing or unreadable image file yields ``(None, (None, None))``
        instead of raising, so callers must filter such samples themselves
        (e.g. in a custom ``collate_fn``).
        """
        row = self.annotations.iloc[idx]
        try:
            # Open the path as-is: Image.open reads from the filesystem, and
            # percent-encoding the path (as a URL) would turn spaces/accents
            # into "%20"/"%C3%A9", making such files unopenable.
            # The context manager closes the underlying file handle.
            with Image.open(row["filename"]) as img:
                image = img.convert("RGB")
            style_label = self.label_map_style[row["genre"]]
            artist_label = self.label_map_artist[row["artist"]]
            if self.transform:
                image = self.transform(image)
            return image, (style_label, artist_label)
        except (FileNotFoundError, OSError):
            return None, (None, None)
# Preprocessing for the ResNet classifier: fixed 224x224 resize, conversion to
# a tensor, then per-channel normalization with the ImageNet mean/std.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Label dataset. Constructing it reads only the CSV (no images); in this file it
# is used solely for its label maps, which size the classifier heads and decode
# the predicted indices back to names.
csv_file = "classes.csv"
dataset = ArtDataset(csv_file, transform=data_transforms)
# Define model
class DualOutputResNet(nn.Module):
    """ResNet-18 backbone shared by two linear heads: one for style, one for artist.

    Args:
        num_styles: Number of style classes (output size of the style head).
        num_artists: Number of artist classes (output size of the artist head).
        pretrained_backbone: If True (default, the original behavior), initialise
            the backbone with torchvision's ``ResNet18_Weights.IMAGENET1K_V1``;
            torchvision downloads these on first use. Pass False when a full
            checkpoint will be loaded right afterwards, which makes the
            ImageNet weights redundant.
    """

    def __init__(self, num_styles, num_artists, *, pretrained_backbone=True):
        super().__init__()
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained_backbone else None
        self.backbone = models.resnet18(weights=weights)
        num_features = self.backbone.fc.in_features
        # Replace the ImageNet classifier with a pass-through so the backbone
        # outputs the feature vector that feeds both heads.
        self.backbone.fc = nn.Identity()
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        """Return ``(style_scores, artist_scores)``: unnormalised scores from both heads."""
        features = self.backbone(x)
        return self.fc_style(features), self.fc_artist(features)
# Rebuild the classifier with head sizes taken from the label maps, then
# restore the trained weights from the checkpoint onto the selected device.
num_styles = len(dataset.label_map_style)
num_artists = len(dataset.label_map_artist)
model = DualOutputResNet(num_styles, num_artists).to(device)
model.load_state_dict(
    torch.load("dual_output_resnet.pth", map_location=device)
)
model.eval()  # inference only: batch-norm layers use their stored statistics

# CLIP ViT-B/32 and its preprocessing pipeline.
# NOTE(review): model_clip / preprocess_clip are not referenced anywhere else in
# this file; unless another module imports them, this load only costs startup
# time and memory.
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.eval()

# GPT-Neo 1.3B causal language model that writes the artwork descriptions.
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Reference description tables used to enrich the prompt. Column names are
# lower-cased so lookups can use "artists", "style" and "description".
dataset_desc = pd.read_csv(
    "dataset_desc.csv",
    delimiter=";",
    usecols=["Artists", "Style", "Description"],
).rename(columns=str.lower)
style_desc = pd.read_csv("style_desc.csv", delimiter=";").rename(columns=str.lower)
# Function to enrich prompt
def _first_description(descriptions):
    """Return the first non-missing value in *descriptions* as a string, or "" if none."""
    for value in descriptions:
        if pd.notna(value):
            return str(value)
    return ""


def enrich_prompt(artist, style, *, artists_df=None, styles_df=None):
    """Build extra prompt context from the artist and style description tables.

    The artist is matched case-insensitively against the ``artists`` column.
    The style is first matched exactly (case-insensitively) against the
    ``style`` column. If that finds nothing, each whitespace-separated word of
    *style* is tried in turn as a literal, case-insensitive substring, and the
    first word that yields a description wins.

    Args:
        artist: Predicted artist name.
        style: Predicted style/genre name.
        artists_df: Table with ``artists`` and ``description`` columns; defaults
            to the module-level ``dataset_desc``.
        styles_df: Table with ``style`` and ``description`` columns; defaults to
            the module-level ``style_desc``.

    Returns:
        ``"<artist description> This work exemplifies <style description>."``.
        Any part whose description is not found (or is NaN) is left out, and
        ``""`` is returned when neither is found.
    """
    artists_df = dataset_desc if artists_df is None else artists_df
    styles_df = style_desc if styles_df is None else styles_df

    artist_mask = artists_df["artists"].str.lower() == artist.lower()
    artist_details = _first_description(artists_df.loc[artist_mask, "description"])

    style_names = styles_df["style"].str.lower()
    style_details = _first_description(styles_df.loc[style_names == style.lower(), "description"])
    if not style_details:
        # Fall back to per-word partial matching, e.g. "Analytical Cubism"
        # matches a row whose style contains "cubism".
        for keyword in style.lower().split():
            # regex=False matches the word literally, so "(" etc. are safe.
            matches = style_names.str.contains(keyword, regex=False, na=False)
            style_details = _first_description(styles_df.loc[matches, "description"])
            if style_details:
                break

    parts = []
    if artist_details:
        parts.append(artist_details)
    if style_details:
        parts.append(f"This work exemplifies {style_details}.")
    # Omitting missing parts keeps the prompt free of a dangling
    # "This work exemplifies ." or a stray leading space.
    return " ".join(parts)
# Function to generate description
def _classify_artwork(image_path):
    """Return the ``(style, artist)`` names predicted by the ResNet classifier for *image_path*."""
    # Context manager closes the file handle once the pixels are loaded.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    batch = data_transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        style_scores, artist_scores = model(batch)
    style_idx = style_scores.argmax(dim=1).item()
    artist_idx = artist_scores.argmax(dim=1).item()
    # Invert the label -> index maps to decode the predicted indices.
    idx_to_style = {v: k for k, v in dataset.label_map_style.items()}
    idx_to_artist = {v: k for k, v in dataset.label_map_artist.items()}
    return idx_to_style[style_idx], idx_to_artist[artist_idx]


def _generate_text(prompt, max_new_tokens=200):
    """Sample a continuation of *prompt* from GPT-Neo and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]
    output = model_gptneo.generate(
        **inputs,  # input_ids plus attention_mask
        # Budget counts generated tokens only; max_length would also count the
        # (variable-length) enriched prompt and could leave no room for output.
        max_new_tokens=max_new_tokens,
        # temperature/top_p have no effect under greedy decoding.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        num_return_sequences=1,
        # The GPT-Neo tokenizer defines no pad token by default.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the echoed prompt so only the model's description is returned.
    return tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True).strip()


def generate_description(image_path):
    """Classify the artwork at *image_path* and generate a text description of it.

    Args:
        image_path: Path of the uploaded image file.

    Returns:
        ``(predicted_style, predicted_artist, description_text)``, where
        ``description_text`` contains only the text GPT-Neo generated after
        the prompt (sampled, so it varies between calls).
    """
    predicted_style, predicted_artist = _classify_artwork(image_path)
    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )
    description_text = _generate_text(full_prompt)
    return predicted_style, predicted_artist, description_text
# Gradio interface
def predict(image):
    """Gradio callback: return a Markdown report for the uploaded image file path.

    Args:
        image: File path supplied by ``gr.Image(type="filepath")``, or None when
            the form is submitted without an image.
    """
    if image is None:
        return "Please upload an image of an artwork."
    style, artist, description = generate_description(image)
    return f"**Predicted Style**: {style}\n\n**Predicted Artist**: {artist}\n\n**Description**:\n{description}"


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),
    # predict returns Markdown (bold labels); a plain "text" output would show
    # the asterisks literally.
    outputs=gr.Markdown(),
    title="AI-Powered Artwork Recognition and Description",
    description="Upload an image of artwork to predict its style and artist, and generate a description."
)

if __name__ == "__main__":
    iface.launch()