import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from sentence_transformers import SentenceTransformer
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gradio as gr

# Device setup
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Preprocessing: resize to ResNet's 224x224 input and normalize with ImageNet mean/std
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load datasets for enriched prompts
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()
style_desc = pd.read_csv("style_desc.csv", delimiter=';')
style_desc.columns = style_desc.columns.str.lower()

# Function to enrich prompts with custom data
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'] == artist, 'description'].values
    style_info = style_desc.loc[style_desc['style'] == style, 'description'].values

    artist_details = artist_info[0] if len(artist_info) > 0 else "Details about the artist are not available."
    style_details = style_info[0] if len(style_info) > 0 else "Details about the style are not available."

    return f"{artist_details} This work exemplifies {style_details}."

# CSV-backed annotation wrapper (not a torch.utils.data.Dataset): exposes
# label maps and the train/test split used by the ResNet18 heads
class ArtDataset:
    def __init__(self, csv_file):
        self.annotations = pd.read_csv(csv_file)
        self.train_data = self.annotations[self.annotations['subset'] == 'train']
        self.test_data = self.annotations[self.annotations['subset'] == 'test']
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def get_style_and_artist_mappings(self):
        return self.label_map_style, self.label_map_artist

    def get_train_test_split(self):
        return self.train_data, self.test_data

# DualOutputResNet model with Dropout
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists, dropout_rate=0.5):
        super(DualOutputResNet, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        features = self.dropout(features)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output
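
# Example: a batch of shape (N, 3, 224, 224) yields two logit tensors,
# (N, num_styles) and (N, num_artists), one per classification head.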

# Load dataset
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
train_data, test_data = dataset.get_train_test_split()
num_styles = len(label_map_style)
num_artists = len(label_map_artist)

# Model setup (optimizer and scheduler are defined for completeness; no
# training loop runs in this script)
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)  # 'verbose' kwarg dropped: deprecated in recent PyTorch
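
# For meaningful predictions, fine-tuned weights should be loaded first;
# the checkpoint filename below is an assumption, not part of this script:
# model_resnet.load_state_dict(torch.load("dual_output_resnet.pt", map_location=device))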

# Load multilingual CLIP SentenceTransformer (loaded here but not used in
# the inference path below)
clip_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1').to(device)

# Load GPT-Neo and set padding token
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-Neo has no dedicated pad token; reuse EOS
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
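# GPT-Neo 1.3B occupies roughly 5 GB in float32; if memory is tight, the
# model could instead be loaded in half precision (optional alternative):
# model_gptneo = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)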


def generate_description(image):
    image_resnet = data_transforms(image).unsqueeze(0).to(device)

    model_resnet.eval()
    with torch.no_grad():
        outputs_style, outputs_artist = model_resnet(image_resnet)
        _, predicted_style_idx = torch.max(outputs_style, 1)
        _, predicted_artist_idx = torch.max(outputs_artist, 1)

    idx_to_style = {v: k for k, v in label_map_style.items()}
    idx_to_artist = {v: k for k, v in label_map_artist.items()}
    predicted_style = idx_to_style[predicted_style_idx.item()]
    predicted_artist = idx_to_artist[predicted_artist_idx.item()]

    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )

    # Tokenize via the tokenizer's __call__, which returns the attention mask
    # directly; deriving it from pad_token_id is fragile here because
    # pad_token == eos_token, so genuine EOS tokens would be masked out.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)

    output = model_gptneo.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=200,  # budget for generated tokens only, independent of prompt length
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.5,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id
    )

    description_text = tokenizer.decode(output[0], skip_special_tokens=True)

    return predicted_style, predicted_artist, description_text
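
# Example usage outside Gradio (the path is illustrative):
# style, artist, text = generate_description(Image.open("sample.jpg").convert("RGB"))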


# Gradio interface
def gradio_interface(image):
    if image is None:
        return "No image provided. Please upload an image."

    # gr.Image(type="filepath") passes the upload as a path string
    image = Image.open(image).convert("RGB")

    predicted_style, predicted_artist, description = generate_description(image)
    return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"

iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="AI Artwork Analysis",
    description="Upload an image to predict its artistic style and creator, and generate a detailed description."
)

if __name__ == "__main__":
    iface.launch()