MusIre committed on
Commit
9d16cc3
·
verified ·
1 Parent(s): c073c7f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import clip
3
+ from PIL import Image
4
+ from torchvision import transforms, models
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ import pandas as pd
7
+ import random
8
+ import urllib.parse
9
+ import torch.nn as nn
10
+ from sklearn.metrics import classification_report
11
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
12
+ import gradio as gr
13
+
14
# Device setup
# Prefer CUDA when present, then Apple's MPS backend, and fall back to the CPU.
# Checking CUDA first keeps NVIDIA machines from silently running on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
17
+
18
# Data transformation
# Pipeline applied to every input image before it reaches the ResNet model:
# resize to 224x224, convert to a tensor, then normalise each channel.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
24
+
25
# Load datasets for enriched prompts
# Both files are semicolon-separated; column names are lower-cased so that
# later lookups can use 'artists', 'style' and 'description'.
dataset_desc = pd.read_csv(
    "dataset_desc.csv",
    sep=';',
    usecols=['Artists', 'Style', 'Description'],
).rename(columns=str.lower)

# CSV containing style-specific descriptions
style_desc = pd.read_csv("style_desc.csv", sep=';').rename(columns=str.lower)
30
+
31
+ # Function to enrich prompts with custom data
32
def enrich_prompt(artist, style):
    """Build the text prompt for the language model from the description tables.

    Args:
        artist: Artist name, matched exactly against ``dataset_desc['artists']``.
        style: Style name, matched exactly against ``style_desc['style']``.

    Returns:
        A string made of the first matching artist description followed by a
        sentence embedding the first matching style description. A fixed
        fallback sentence is used for whichever lookup finds no row.
    """

    def first_description(table, key_column, wanted, fallback):
        # Take the first matching row's description, or the fallback text.
        matches = table.loc[table[key_column] == wanted, 'description'].values
        if len(matches) > 0:
            return matches[0]
        return fallback

    artist_details = first_description(
        dataset_desc, 'artists', artist,
        "Details about the artist are not available.",
    )
    style_details = first_description(
        style_desc, 'style', style,
        "Details about the style are not available.",
    )

    return f"{artist_details} This work exemplifies {style_details}."
40
+
41
+ # Custom dataset for ResNet18
42
class ArtDataset:
    """Annotation table for the artwork classifier, read from a CSV file.

    The CSV must provide the columns ``subset`` (rows marked ``'train'`` or
    ``'test'``), ``genre`` and ``artist``. Label maps assign each distinct
    genre / artist an integer index in order of first appearance.
    """

    def __init__(self, csv_file):
        """Read ``csv_file`` and derive the train/test views and label maps."""
        frame = pd.read_csv(csv_file)
        subset = frame['subset']

        self.annotations = frame
        self.train_data = frame[subset == 'train']
        self.test_data = frame[subset == 'test']
        self.label_map_style = self._index_labels(frame['genre'])
        self.label_map_artist = self._index_labels(frame['artist'])

    @staticmethod
    def _index_labels(column):
        """Map each distinct value of ``column`` to its first-seen position."""
        distinct = column.unique()
        return dict(zip(distinct, range(len(distinct))))

    def get_style_and_artist_mappings(self):
        """Return ``(label_map_style, label_map_artist)``."""
        return self.label_map_style, self.label_map_artist

    def get_train_test_split(self):
        """Return ``(train_data, test_data)`` as filtered DataFrames."""
        return self.train_data, self.test_data
55
+
56
+ # DualOutputResNet model with Dropout
57
class DualOutputResNet(nn.Module):
    """ResNet18 backbone with two linear heads: one for style, one for artist.

    The backbone's own final layer is replaced by an identity, and dropout is
    applied to the resulting features before both heads.
    """

    def __init__(self, num_styles, num_artists, dropout_rate=0.5):
        """Create the backbone and heads.

        Args:
            num_styles: Output size of the style head.
            num_artists: Output size of the artist head.
            dropout_rate: Dropout probability applied to the backbone features.
        """
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        feature_dim = backbone.fc.in_features
        # Remove the original classifier so the backbone emits raw features.
        backbone.fc = nn.Identity()

        self.backbone = backbone
        self.dropout = nn.Dropout(dropout_rate)
        self.fc_style = nn.Linear(feature_dim, num_styles)
        self.fc_artist = nn.Linear(feature_dim, num_artists)

    def forward(self, x):
        """Return ``(style_output, artist_output)`` for the input batch ``x``."""
        shared = self.dropout(self.backbone(x))
        return self.fc_style(shared), self.fc_artist(shared)
73
+
74
# Load dataset
# The annotation CSV supplies the label vocabularies that size the model heads.
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
train_data, test_data = dataset.get_train_test_split()
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
num_styles, num_artists = len(label_map_style), len(label_map_artist)
81
+
82
# Model setup
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
# This script only runs inference, so switch to evaluation mode: otherwise the
# dropout layer stays active and predictions for the same image vary per call.
model_resnet.eval()
# NOTE(review): no load_state_dict call is visible in this file, so the two
# classification heads appear to keep their initial weights — confirm where
# the trained checkpoint is supposed to be loaded.

# Optimizer and scheduler are kept so the module still exposes these names;
# nothing in this file steps them.
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
# `verbose` is deprecated in recent PyTorch releases, so it is no longer passed.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
86
+
87
# Load GPT-Neo and CLIP
# CLIP is loaded first and put in evaluation mode.
# NOTE(review): predict() in this file does not reference model_clip or
# preprocess_clip — confirm whether CLIP is still needed.
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.eval()

# Tokenizer and causal language model used to generate the description text.
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name)
model_gptneo = model_gptneo.to(device)
94
+
95
+ # Generate prediction using ResNet and CLIP
96
def predict(image_path):
    """Classify an artwork image and generate a text description for it.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        A tuple ``(predicted_style, predicted_artist, description)`` where the
        first two are labels from ``label_map_style`` / ``label_map_artist``
        and ``description`` is the decoded GPT-Neo output for the enriched
        prompt.
    """
    # convert() returns a new, fully loaded image, so the file handle can be
    # released as soon as the conversion is done.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    image_tensor = data_transforms(image).unsqueeze(0).to(device)

    # Predict with ResNet. Evaluation mode disables dropout so the same image
    # always yields the same prediction; no_grad skips autograd bookkeeping.
    model_resnet.eval()
    with torch.no_grad():
        style_logits, artist_logits = model_resnet(image_tensor)
    style_idx = torch.argmax(style_logits, dim=1).item()
    artist_idx = torch.argmax(artist_logits, dim=1).item()

    # Invert the label -> index maps to turn the winning indices into names.
    style_by_index = {idx: name for name, idx in label_map_style.items()}
    artist_by_index = {idx: name for name, idx in label_map_artist.items()}
    predicted_style = style_by_index[style_idx]
    predicted_artist = artist_by_index[artist_idx]

    # Enrich prompt with additional information
    prompt = enrich_prompt(predicted_artist, predicted_style)

    # Generate text description using GPT-Neo
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output = model_gptneo.generate(input_ids, max_length=350, num_return_sequences=1)
    description = tokenizer.decode(output[0], skip_special_tokens=True)

    return predicted_style, predicted_artist, description
117
+
118
+ # Gradio interface
119
def gradio_interface(image):
    """Format the prediction for an uploaded image as text for the Gradio UI.

    Args:
        image: File path of the uploaded image, or ``None`` when the form is
            submitted without an image.

    Returns:
        A multi-line string with the predicted style, predicted artist and the
        generated description, or a short prompt to upload an image when no
        image was provided.
    """
    # Without this guard a submit with no image reaches predict() and crashes
    # while trying to open the file.
    if image is None:
        return "Please upload an image to analyze."

    predicted_style, predicted_artist, description = predict(image)
    return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"
122
+
123
# Web UI: one image upload (handed to the callback as a file path), one text output.
iface = gr.Interface(
    title="AI Artwork Analysis",
    description="Upload an image to predict its artistic style and creator, and generate a detailed description.",
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),
    outputs="text",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()