MusIre committed
Commit 43146c8 · verified · 1 Parent(s): 3ba4861

Create artworksApp.py

Files changed (1)
  1. artworksApp.py +157 -0
artworksApp.py ADDED
@@ -0,0 +1,157 @@
+ import gradio as gr
+ import torch
+ import clip
+ from PIL import Image
+ from torchvision import transforms, models
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import pandas as pd
+ from torch.utils.data import Dataset
+ import torch.nn as nn
+ import urllib.parse
+ import re
+
+ # Set device
+ if torch.backends.mps.is_available():
+     device = torch.device("mps")
+     print("Using MPS device")
+ else:
+     device = torch.device("cpu")
+     print("Using CPU device")
+
+ # Dataset class
+ class ArtDataset(Dataset):
+     def __init__(self, csv_file, transform=None):
+         self.annotations = pd.read_csv(csv_file)
+         self.transform = transform
+         self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
+         self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}
+
+     def __len__(self):
+         return len(self.annotations)
+
+     def __getitem__(self, idx):
+         img_path = self.annotations.iloc[idx]['filename']
+         safe_img_path = urllib.parse.quote(img_path, safe="/:")
+         try:
+             image = Image.open(safe_img_path).convert("RGB")
+             style_label = self.label_map_style[self.annotations.iloc[idx]['genre']]
+             artist_label = self.label_map_artist[self.annotations.iloc[idx]['artist']]
+             if self.transform:
+                 image = self.transform(image)
+             return image, (style_label, artist_label)
+         except (FileNotFoundError, OSError):
+             # Skip unreadable or missing images; callers must handle None entries.
+             return None, (None, None)
+
+ # Image transformations
+ data_transforms = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ # Load dataset
+ csv_file = "classes.csv"
+ dataset = ArtDataset(csv_file=csv_file, transform=data_transforms)
+
+ # Define model: shared ResNet-18 backbone with two classification heads
+ class DualOutputResNet(nn.Module):
+     def __init__(self, num_styles, num_artists):
+         super(DualOutputResNet, self).__init__()
+         self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
+         num_features = self.backbone.fc.in_features
+         self.backbone.fc = nn.Identity()
+         self.fc_style = nn.Linear(num_features, num_styles)
+         self.fc_artist = nn.Linear(num_features, num_artists)
+
+     def forward(self, x):
+         features = self.backbone(x)
+         style_output = self.fc_style(features)
+         artist_output = self.fc_artist(features)
+         return style_output, artist_output
+
+ # Load pre-trained model
+ num_styles = len(dataset.label_map_style)
+ num_artists = len(dataset.label_map_artist)
+ model = DualOutputResNet(num_styles, num_artists).to(device)
+ model.load_state_dict(torch.load("dual_output_resnet.pth", map_location=device))
+ model.eval()
+
+ # Load CLIP model (loaded here but not used elsewhere in this file)
+ model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
+ model_clip.eval()
+
+ # Load GPT-Neo model
+ model_name = "EleutherAI/gpt-neo-1.3B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+
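+ # NOTE: `enrich_prompt` below reads from `dataset_desc` and `style_desc`,
+ # which this commit never defines, so the app would raise a NameError as-is.
+ # A minimal sketch of the missing loads follows; the CSV filenames and the
+ # 'artists'/'style'/'description' column names are assumptions inferred from
+ # how `enrich_prompt` indexes these frames, not part of the original file.
+ dataset_desc = pd.read_csv("artists_descriptions.csv")  # assumed file; needs columns 'artists', 'description'
+ style_desc = pd.read_csv("styles_descriptions.csv")     # assumed file; needs columns 'style', 'description'
+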
+ # Function to enrich prompt
+ def enrich_prompt(artist, style):
+     artist_info = dataset_desc.loc[dataset_desc['artists'].str.lower() == artist.lower(), 'description'].values
+     style_info = style_desc.loc[style_desc['style'].str.lower() == style.lower(), 'description'].values
+
+     # Fall back to keyword-based partial matching when there is no exact style match.
+     if len(style_info) == 0:
+         style_keywords = style.lower().split()
+         for keyword in style_keywords:
+             safe_keyword = re.escape(keyword)
+             partial_matches = style_desc[style_desc['style'].str.lower().str.contains(safe_keyword, na=False, regex=True)]
+             if not partial_matches.empty:
+                 style_info = partial_matches['description'].values
+                 break
+
+     artist_details = artist_info[0] if len(artist_info) > 0 else ""
+     style_details = style_info[0] if len(style_info) > 0 else ""
+
+     return f"{artist_details} This work exemplifies {style_details}."
+
+ # Function to generate description
+ def generate_description(image_path):
+     image = Image.open(image_path).convert("RGB")
+     image_resnet = data_transforms(image).unsqueeze(0).to(device)
+
+     # Predict style and artist
+     with torch.no_grad():
+         outputs_style, outputs_artist = model(image_resnet)
+         _, predicted_style_idx = torch.max(outputs_style, 1)
+         _, predicted_artist_idx = torch.max(outputs_artist, 1)
+
+     idx_to_style = {v: k for k, v in dataset.label_map_style.items()}
+     idx_to_artist = {v: k for k, v in dataset.label_map_artist.items()}
+     predicted_style = idx_to_style[predicted_style_idx.item()]
+     predicted_artist = idx_to_artist[predicted_artist_idx.item()]
+
+     # Enrich prompt
+     enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
+     full_prompt = (
+         f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
+         "Describe its distinctive features, considering both the artist's techniques and the artistic style."
+     )
+
+     input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
+     output = model_gptneo.generate(
+         input_ids=input_ids,
+         max_length=350,
+         num_return_sequences=1,
+         do_sample=True,  # required for temperature/top_p to take effect
+         temperature=0.7,
+         top_p=0.9,
+         repetition_penalty=1.2,
+         pad_token_id=tokenizer.eos_token_id  # GPT-Neo has no dedicated pad token
+     )
+
+     description_text = tokenizer.decode(output[0], skip_special_tokens=True)
+     return predicted_style, predicted_artist, description_text
+
+ # Gradio interface
+ def predict(image):
+     style, artist, description = generate_description(image)
+     return f"**Predicted Style**: {style}\n\n**Predicted Artist**: {artist}\n\n**Description**:\n{description}"
+
+ iface = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="filepath"),  # "file" is not a valid Image type in current Gradio; the handler expects a path
+     outputs="text",
+     title="AI-Powered Artwork Recognition and Description",
+     description="Upload an image of artwork to predict its style and artist, and generate a description."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
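
Running this file assumes `classes.csv`, the trained weights `dual_output_resnet.pth`, and the artist/style description tables are present alongside it, and that the `clip` import is OpenAI's CLIP package (installed via `pip install git+https://github.com/openai/CLIP.git`), not the unrelated `clip` module on PyPI. A minimal sketch of a smoke test that exercises the pipeline without the browser UI, under those assumptions (the image path is a placeholder, and importing the module loads all models up front):

# smoke_test.py - quick check of the prediction pipeline (hypothetical helper, not part of the commit)
from artworksApp import generate_description

# Placeholder path; substitute any local artwork image.
style, artist, text = generate_description("sample_artwork.jpg")
print(f"style={style}, artist={artist}")
print(text)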