Spaces:
Sleeping
Sleeping
File size: 6,288 Bytes
43146c8 e137b2e 43146c8 18c0502 43146c8 8d2c3ff 43146c8 8d2c3ff 43146c8 5a0ed9f 43146c8 abc4ab8 43146c8 c16b43a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import gradio as gr
import torch
import clip
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from torch.utils.data import Dataset
import torch.nn as nn
import urllib.parse
import re
# Pick the compute device once: Apple's MPS backend when available, CPU otherwise.
# NOTE(review): CUDA is never considered here - confirm that is intentional.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
if device.type == "mps":
    print("Utilizzo del dispositivo MPS")
else:
    print("Utilizzo del dispositivo CPU")
# Dataset class
class ArtDataset(Dataset):
    """Artwork images listed in a semicolon-delimited CSV, labelled by genre and artist.

    The CSV must provide the columns ``filename``, ``genre`` and ``artist``.
    Every distinct genre / artist value is assigned an integer id in order of
    first appearance; the mappings are exposed as ``label_map_style`` and
    ``label_map_artist``.
    """

    def __init__(self, csv_file, transform=None):
        """Read the annotation table and build the label-to-index maps.

        Args:
            csv_file: Path (or buffer) of the ``;``-delimited annotation CSV.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.annotations = pd.read_csv(csv_file, delimiter=";")
        self.transform = transform

        genres = self.annotations['genre'].unique()
        artists = self.annotations['artist'].unique()
        self.label_map_style = dict(zip(genres, range(len(genres))))
        self.label_map_artist = dict(zip(artists, range(len(artists))))

    def __len__(self):
        """Return the number of annotated rows."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, (style_label, artist_label))``, or
            ``(None, (None, None))`` when the image cannot be opened.
        """
        row = self.annotations.iloc[idx]
        # NOTE(review): the path is percent-encoded before being opened as a
        # local file - presumably the files on disk carry encoded names; confirm.
        encoded_path = urllib.parse.quote(row['filename'], safe="/:")
        try:
            image = Image.open(encoded_path).convert("RGB")
            labels = (
                self.label_map_style[row['genre']],
                self.label_map_artist[row['artist']],
            )
            if self.transform:
                image = self.transform(image)
        except (FileNotFoundError, OSError):
            # Best effort: unreadable images yield an all-None sample instead of raising.
            # NOTE(review): callers presumably filter these out - verify.
            return None, (None, None)
        return image, labels
# Preprocessing applied to every image before it reaches the classifier:
# resize to 224x224, convert to a tensor, then normalise each channel.
# NOTE(review): the mean/std values are presumably the ImageNet statistics
# expected by the pretrained ResNet backbone - confirm.
data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Annotation table; the rest of this file reads its label maps.
csv_file = "classes.csv"
dataset = ArtDataset(csv_file, transform=data_transforms)
# Define model
class DualOutputResNet(nn.Module):
    """ResNet-18 feature extractor feeding two linear heads: style and artist."""

    def __init__(self, num_styles, num_artists):
        """Build the backbone and both classification heads.

        Args:
            num_styles: Number of output classes for the style head.
            num_artists: Number of output classes for the artist head.
        """
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        feature_dim = self.backbone.fc.in_features
        # Replace the stock classifier so the backbone emits raw features.
        self.backbone.fc = nn.Identity()
        self.fc_style = nn.Linear(feature_dim, num_styles)
        self.fc_artist = nn.Linear(feature_dim, num_artists)

    def forward(self, x):
        """Return ``(style_output, artist_output)`` computed from shared features."""
        shared = self.backbone(x)
        return self.fc_style(shared), self.fc_artist(shared)
# Style/artist classifier: both heads are sized from the dataset's label maps,
# then the trained weights are restored and the network is put in eval mode.
num_styles, num_artists = len(dataset.label_map_style), len(dataset.label_map_artist)
model = DualOutputResNet(num_styles, num_artists).to(device)
# NOTE(review): torch.load unpickles the checkpoint - only load trusted files.
model.load_state_dict(torch.load("dual_output_resnet.pth", map_location=device))
model.eval()

# CLIP model and its preprocessing pipeline.
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.eval()

# GPT-Neo tokenizer and causal language model used for text generation.
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name)
model_gptneo = model_gptneo.to(device)

# Description tables; column names are lower-cased for uniform access.
dataset_desc = pd.read_csv(
    "dataset_desc.csv",
    sep=';',
    usecols=['Artists', 'Style', 'Description'],
).rename(columns=str.lower)
style_desc = pd.read_csv("style_desc.csv", sep=';').rename(columns=str.lower)
# Function to enrich prompt
def enrich_prompt(artist, style):
    """Build extra prompt context from the stored artist and style descriptions.

    The artist is matched case-insensitively against ``dataset_desc['artists']``.
    The style is matched case-insensitively against ``style_desc['style']``; if
    there is no usable exact match, each word of ``style`` is tried in turn as
    a substring and the first word that yields a usable description wins.

    Args:
        artist: Artist name to look up.
        style: Style name to look up.

    Returns:
        ``"<artist description> This work exemplifies <style description>."``
        when both descriptions exist. A missing part is left out entirely
        (rather than producing a dangling ``"This work exemplifies ."``), so
        the result may be a single sentence or an empty string.
    """

    def _first_text(values):
        # Empty CSV cells are read as NaN (a float); formatting one into the
        # prompt would inject the literal text "nan", so keep real text only.
        return next(
            (value for value in values if isinstance(value, str) and value.strip()),
            "",
        )

    artist_names = dataset_desc['artists'].str.lower()
    artist_details = _first_text(
        dataset_desc.loc[artist_names == artist.lower(), 'description'].values
    )

    # Lower-case the style column once instead of once per keyword.
    style_names = style_desc['style'].str.lower()
    wanted_style = style.lower()
    style_details = _first_text(
        style_desc.loc[style_names == wanted_style, 'description'].values
    )
    if not style_details:
        # Fall back to a partial match on the individual words of the style name.
        for keyword in wanted_style.split():
            matches = style_names.str.contains(re.escape(keyword), na=False, regex=True)
            style_details = _first_text(style_desc.loc[matches, 'description'].values)
            if style_details:
                break

    sentences = []
    if artist_details:
        sentences.append(artist_details)
    if style_details:
        sentences.append(f"This work exemplifies {style_details}.")
    return " ".join(sentences)
# Function to generate description
def generate_description(image_path):
    """Classify an artwork image and generate a textual description of it.

    The image is run through the dual-head classifier (``model``) to predict a
    style and an artist. Those predictions, plus any stored descriptions
    returned by ``enrich_prompt``, form a prompt that GPT-Neo continues.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A ``(predicted_style, predicted_artist, description_text)`` tuple.
        ``description_text`` is the decoded output sequence, i.e. the prompt
        followed by the generated continuation.
    """
    # Close the file handle as soon as the pixels have been decoded.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    image_resnet = data_transforms(image).unsqueeze(0).to(device)

    # Predict style and artist.
    with torch.no_grad():
        outputs_style, outputs_artist = model(image_resnet)
    _, predicted_style_idx = torch.max(outputs_style, 1)
    _, predicted_artist_idx = torch.max(outputs_artist, 1)

    # The label maps go name -> index; invert them to recover the names.
    idx_to_style = {v: k for k, v in dataset.label_map_style.items()}
    idx_to_artist = {v: k for k, v in dataset.label_map_artist.items()}
    predicted_style = idx_to_style[predicted_style_idx.item()]
    predicted_artist = idx_to_artist[predicted_artist_idx.item()]

    # Enrich prompt.
    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )

    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
    output = model_gptneo.generate(
        input_ids=input_ids,
        # Single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        # Budget the continuation itself. max_length would also count the
        # prompt tokens, so a long enriched prompt could leave no room to
        # generate anything.
        max_new_tokens=250,
        num_return_sequences=1,
        # temperature / top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        # Explicit padding id; avoids relying on generate()'s implicit fallback.
        pad_token_id=tokenizer.eos_token_id,
    )
    description_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return predicted_style, predicted_artist, description_text
# Gradio interface
def predict(image):
    """Gradio callback: format the predictions for an uploaded artwork.

    Args:
        image: File path of the uploaded image, or ``None`` / empty when the
            form is submitted without an image.

    Returns:
        A Markdown-formatted string with the predicted style, predicted artist
        and generated description, or a short prompt to upload an image when
        none was provided.
    """
    # Without this guard a missing upload would reach Image.open(None) and
    # surface as an opaque error in the UI.
    if not image:
        return "Please upload an image of an artwork first."
    style, artist, description = generate_description(image)
    return f"**Predicted Style**: {style}\n\n**Predicted Artist**: {artist}\n\n**Description**:\n{description}"
# Gradio UI: one image upload (passed to predict() as a file path) and one
# output area for the formatted result.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),
    # predict() returns Markdown (bold labels, blank-line paragraph breaks);
    # a plain "text" output would display the asterisks literally.
    outputs="markdown",
    title="AI-Powered Artwork Recognition and Description",
    description="Upload an image of artwork to predict its style and artist, and generate a description.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|