import torch
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from sentence_transformers import SentenceTransformer
import random
import urllib.parse
import torch.nn as nn
from sklearn.metrics import classification_report
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gradio as gr
from io import BytesIO
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
# Data transformation
data_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
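# The mean/std above are the standard ImageNet statistics that the pretrained
# ResNet18 backbone expects; inference inputs must be normalized the same way
# the backbone was trained.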
# Load datasets for enriched prompts
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()
style_desc = pd.read_csv("style_desc.csv", delimiter=';')
style_desc.columns = style_desc.columns.str.lower()
# Function to enrich prompts with custom data
def enrich_prompt(artist, style):
artist_info = dataset_desc.loc[dataset_desc['artists'] == artist, 'description'].values
style_info = style_desc.loc[style_desc['style'] == style, 'description'].values
artist_details = artist_info[0] if len(artist_info) > 0 else "Details about the artist are not available."
style_details = style_info[0] if len(style_info) > 0 else "Details about the style are not available."
return f"{artist_details} This work exemplifies {style_details}."
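# Illustrative call (hypothetical rows; the actual text depends on the CSV contents):
#   enrich_prompt("Claude Monet", "Impressionism")
#   -> "<artist description> This work exemplifies <style description>."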
# CSV-backed dataset helper: builds the label maps and train/test split for the ResNet18 classifier
class ArtDataset:
def __init__(self, csv_file):
self.annotations = pd.read_csv(csv_file)
self.train_data = self.annotations[self.annotations['subset'] == 'train']
self.test_data = self.annotations[self.annotations['subset'] == 'test']
self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}
def get_style_and_artist_mappings(self):
return self.label_map_style, self.label_map_artist
def get_train_test_split(self):
return self.train_data, self.test_data
# DualOutputResNet model with Dropout
class DualOutputResNet(nn.Module):
def __init__(self, num_styles, num_artists, dropout_rate=0.5):
super(DualOutputResNet, self).__init__()
self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_features = self.backbone.fc.in_features
self.backbone.fc = nn.Identity()
self.dropout = nn.Dropout(dropout_rate)
self.fc_style = nn.Linear(num_features, num_styles)
self.fc_artist = nn.Linear(num_features, num_artists)
def forward(self, x):
features = self.backbone(x)
features = self.dropout(features)
style_output = self.fc_style(features)
artist_output = self.fc_artist(features)
return style_output, artist_output
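# Quick shape sanity check for the two-headed model (a sketch, not part of the
# app's runtime; the 5/7 class counts are made up for illustration):
#   dummy = torch.randn(1, 3, 224, 224)
#   style_logits, artist_logits = DualOutputResNet(num_styles=5, num_artists=7)(dummy)
#   assert style_logits.shape == (1, 5) and artist_logits.shape == (1, 7)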
# Load dataset
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
train_data, test_data = dataset.get_train_test_split()
num_styles = len(label_map_style)
num_artists = len(label_map_artist)
# Model setup
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
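# Note: the optimizer and scheduler are defined but never stepped in this file,
# so the fc_style/fc_artist heads stay randomly initialized unless trained
# weights are loaded elsewhere (e.g., via model_resnet.load_state_dict).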
# Load SentenceTransformer model
clip_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1').to(device)
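# Note: clip_model is loaded here but not referenced by the inference path below.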
# Load GPT-Neo and set padding token
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token # Set pad_token to eos_token
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
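# GPT-Neo ships without a dedicated padding token, so the eos token is reused
# for padding above and passed as pad_token_id to generate() below.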
def generate_description(image):
image_resnet = data_transforms(image).unsqueeze(0).to(device)
model_resnet.eval()
with torch.no_grad():
outputs_style, outputs_artist = model_resnet(image_resnet)
_, predicted_style_idx = torch.max(outputs_style, 1)
_, predicted_artist_idx = torch.max(outputs_artist, 1)
idx_to_style = {v: k for k, v in label_map_style.items()}
idx_to_artist = {v: k for k, v in label_map_artist.items()}
predicted_style = idx_to_style[predicted_style_idx.item()]
predicted_artist = idx_to_artist[predicted_artist_idx.item()]
enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
full_prompt = (
f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
"Describe its distinctive features, considering both the artist's techniques and the artistic style."
)
    # tokenizer() returns both input_ids and a proper attention_mask; encode()
    # alone does not, and comparing against pad_token_id is unreliable here
    # because pad_token is aliased to eos_token.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
output = model_gptneo.generate(
input_ids=input_ids,
attention_mask=attention_mask,
        max_new_tokens=150,  # budget for generated text only; max_length would also count the prompt
temperature=0.7,
top_p=0.9,
repetition_penalty=1.5,
do_sample=True,
pad_token_id=tokenizer.pad_token_id
)
    # Decode only the newly generated tokens, not the echoed prompt
    description_text = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
return predicted_style, predicted_artist, description_text
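# Example usage (sketch; "artwork.jpg" is a hypothetical local file):
#   img = Image.open("artwork.jpg").convert("RGB")
#   style, artist, text = generate_description(img)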
# Gradio interface
def gradio_interface(image):
if image is None:
return "No image provided. Please upload an image."
    # gr.Image(type="filepath") provides a path string, so open it directly
    image = Image.open(image).convert("RGB")
predicted_style, predicted_artist, description = generate_description(image)
return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"
iface = gr.Interface(
fn=gradio_interface,
inputs=gr.Image(type="filepath"),
outputs="text",
title="AI Artwork Analysis",
description="Upload an image to predict its artistic style and creator, and generate a detailed description."
)
if __name__ == "__main__":
iface.launch()