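"""Gradio app for AI-powered artwork recognition and description.

Pipeline: a dual-head ResNet-18 predicts an artwork's style and artist; the
predictions are enriched with descriptions from two CSV tables and passed to
GPT-Neo, which generates a textual description of the piece.
"""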
import gradio as gr
import torch
import clip
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from torch.utils.data import Dataset
import torch.nn as nn
import urllib.parse
import re

# Set device (prefer Apple MPS when available, otherwise fall back to CPU)
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("Using CPU device")

# Dataset class: reads image paths and labels from a semicolon-delimited CSV
class ArtDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.annotations = pd.read_csv(csv_file, delimiter=";")
        self.transform = transform
        # Map each unique genre/artist string to an integer class index.
        # Note: the mapping follows CSV row order, so the file must match the
        # one used when the checkpoint was trained.
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = self.annotations.iloc[idx]['filename']
        # Percent-encode the path; this assumes the image files on disk are
        # stored under URL-encoded names, as exported from the source dataset
        safe_img_path = urllib.parse.quote(img_path, safe="/:")
        try:
            image = Image.open(safe_img_path).convert("RGB")
            style_label = self.label_map_style[self.annotations.iloc[idx]['genre']]
            artist_label = self.label_map_artist[self.annotations.iloc[idx]['artist']]
            if self.transform:
                image = self.transform(image)
            return image, (style_label, artist_label)
        except (FileNotFoundError, OSError):
            # Missing or unreadable image: signal failure with None
            return None, (None, None)
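
# Note: __getitem__ yields None for unreadable files; if this dataset is ever
# fed to a DataLoader for training, filter those entries with a custom
# collate_fn. In this app the dataset is only used for its label maps.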

# Image transformations (224x224, ImageNet mean/std to match the pretrained backbone)
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load dataset (used here only for its label maps)
csv_file = "classes.csv"
dataset = ArtDataset(csv_file=csv_file, transform=data_transforms)

# Define model: a shared ResNet-18 backbone with separate style and artist heads
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists):
        super(DualOutputResNet, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # drop the ImageNet classifier; keep pooled features
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)  # shared 512-dimensional feature vector
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output

# Load pre-trained model
num_styles = len(dataset.label_map_style)
num_artists = len(dataset.label_map_artist)
model = DualOutputResNet(num_styles, num_artists).to(device)
model.load_state_dict(torch.load("dual_output_resnet.pth", map_location=device))
model.eval()

# Load CLIP model (loaded but not used in the current pipeline)
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.eval()

# Load GPT-Neo model
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
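
# Note: GPT-Neo 1.3B in fp32 needs roughly 5 GB of memory for the weights alone;
# a smaller checkpoint (e.g. "EleutherAI/gpt-neo-125M") can be swapped in on
# constrained machines.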


# Load artist and style description tables used to enrich the GPT-Neo prompt
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()
style_desc = pd.read_csv("style_desc.csv", delimiter=';')  
style_desc.columns = style_desc.columns.str.lower()

# Function to enrich the prompt with artist and style descriptions
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'].str.lower() == artist.lower(), 'description'].values
    style_info = style_desc.loc[style_desc['style'].str.lower() == style.lower(), 'description'].values

    # Fall back to a keyword-based partial match when the style has no exact entry
    if len(style_info) == 0:
        style_keywords = style.lower().split()
        for keyword in style_keywords:
            safe_keyword = re.escape(keyword)  # escape any regex metacharacters
            partial_matches = style_desc[style_desc['style'].str.lower().str.contains(safe_keyword, na=False, regex=True)]
            if not partial_matches.empty:
                style_info = partial_matches['description'].values
                break

    artist_details = artist_info[0] if len(artist_info) > 0 else ""
    style_details = style_info[0] if len(style_info) > 0 else ""

    # Avoid a dangling "This work exemplifies ." when no style description is found
    if not style_details:
        return artist_details
    return f"{artist_details} This work exemplifies {style_details}.".strip()

# Function to generate description
def generate_description(image_path):
    image = Image.open(image_path).convert("RGB")
    image_resnet = data_transforms(image).unsqueeze(0).to(device)

    # Predict style and artist
    with torch.no_grad():
        outputs_style, outputs_artist = model(image_resnet)
        _, predicted_style_idx = torch.max(outputs_style, 1)
        _, predicted_artist_idx = torch.max(outputs_artist, 1)

    # Invert the label maps to recover human-readable names from class indices
    idx_to_style = {v: k for k, v in dataset.label_map_style.items()}
    idx_to_artist = {v: k for k, v in dataset.label_map_artist.items()}
    predicted_style = idx_to_style[predicted_style_idx.item()]
    predicted_artist = idx_to_artist[predicted_artist_idx.item()]

    # Enrich prompt
    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )

    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
    output = model_gptneo.generate(
        input_ids=input_ids,
        max_new_tokens=250,  # budget generated tokens only, so a long prompt is not cut short
        num_return_sequences=1,
        do_sample=True,  # required for temperature/top_p to have any effect
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id  # GPT-Neo has no pad token
    )

    # Decode only the newly generated tokens, dropping the echoed prompt
    description_text = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
    return predicted_style, predicted_artist, description_text

# Gradio interface
def predict(image):
    style, artist, description = generate_description(image)
    return f"**Predicted Style**: {style}\n\n**Predicted Artist**: {artist}\n\n**Description**:\n{description}"

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Markdown(),  # render the bold labels returned by predict()
    title="AI-Powered Artwork Recognition and Description",
    description="Upload an image of an artwork to predict its style and artist and generate a description."
)

if __name__ == "__main__":
    iface.launch()  # pass share=True to expose a temporary public URL