import gradio as gr
import torch
import clip  # OpenAI CLIP (https://github.com/openai/CLIP)
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from torch.utils.data import Dataset
import torch.nn as nn
import urllib.parse
import re
# Set device: prefer Apple's MPS backend when available, otherwise fall back to CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("Using CPU device")
# Dataset class
class ArtDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.annotations = pd.read_csv(csv_file, delimiter=";")
        self.transform = transform
        # Map each unique genre/artist string to an integer class index
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = self.annotations.iloc[idx]['filename']
        safe_img_path = urllib.parse.quote(img_path, safe="/:")
        try:
            image = Image.open(safe_img_path).convert("RGB")
            style_label = self.label_map_style[self.annotations.iloc[idx]['genre']]
            artist_label = self.label_map_artist[self.annotations.iloc[idx]['artist']]
            if self.transform:
                image = self.transform(image)
            return image, (style_label, artist_label)
        except (FileNotFoundError, OSError):
            # Missing or unreadable image: signal the caller to skip this sample
            return None, (None, None)
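
# Hedged sketch (not part of the original app): __getitem__ returns None for
# missing or corrupt files, which would break the default DataLoader collation
# at training time. A minimal collate_fn that drops such samples could look
# like this (the batch_size below is an assumption):
def _skip_broken_collate(batch):
    batch = [item for item in batch if item[0] is not None]
    return torch.utils.data.default_collate(batch) if batch else None
# loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=_skip_broken_collate)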
# Image transformations (ImageNet mean/std, matching the ResNet18 pretraining)
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Load dataset
csv_file = "classes.csv"
dataset = ArtDataset(csv_file=csv_file, transform=data_transforms)
# Define model
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists):
        super(DualOutputResNet, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # expose pooled features; two heads below
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output
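
# Hedged sketch (not part of the original app): how the two heads are
# presumably supervised at training time, with one cross-entropy loss per head,
# summed. The names images, style_labels, artist_labels, and optimizer are
# assumptions, so the sketch is left commented out:
# criterion = nn.CrossEntropyLoss()
# style_out, artist_out = model(images)
# loss = criterion(style_out, style_labels) + criterion(artist_out, artist_labels)
# optimizer.zero_grad(); loss.backward(); optimizer.step()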
# Load pre-trained model
num_styles = len(dataset.label_map_style)
num_artists = len(dataset.label_map_artist)
model = DualOutputResNet(num_styles, num_artists).to(device)
model.load_state_dict(torch.load("dual_output_resnet.pth", map_location=device))
model.eval()
# Load CLIP model
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.eval()
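
# model_clip/preprocess_clip are loaded above but never used below. A minimal
# sketch of zero-shot style scoring with them (image_path is hypothetical):
# image_clip = preprocess_clip(Image.open(image_path)).unsqueeze(0).to(device)
# text_tokens = clip.tokenize(list(dataset.label_map_style.keys())).to(device)
# with torch.no_grad():
#     logits_per_image, _ = model_clip(image_clip, text_tokens)
#     style_probs = logits_per_image.softmax(dim=-1)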
# Load GPT-Neo model
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Load description tables used to enrich the generation prompt
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()
style_desc = pd.read_csv("style_desc.csv", delimiter=';')
style_desc.columns = style_desc.columns.str.lower()
# Function to enrich prompt
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'].str.lower() == artist.lower(), 'description'].values
    style_info = style_desc.loc[style_desc['style'].str.lower() == style.lower(), 'description'].values

    # Fall back to keyword-by-keyword partial matching when the exact style name is absent
    if len(style_info) == 0:
        style_keywords = style.lower().split()
        for keyword in style_keywords:
            safe_keyword = re.escape(keyword)
            partial_matches = style_desc[style_desc['style'].str.lower().str.contains(safe_keyword, na=False, regex=True)]
            if not partial_matches.empty:
                style_info = partial_matches['description'].values
                break

    artist_details = artist_info[0] if len(artist_info) > 0 else ""
    style_details = style_info[0] if len(style_info) > 0 else ""
    return f"{artist_details} This work exemplifies {style_details}."
# Function to generate description
def generate_description(image_path):
    image = Image.open(image_path).convert("RGB")
    image_resnet = data_transforms(image).unsqueeze(0).to(device)

    # Predict style and artist
    with torch.no_grad():
        outputs_style, outputs_artist = model(image_resnet)
        _, predicted_style_idx = torch.max(outputs_style, 1)
        _, predicted_artist_idx = torch.max(outputs_artist, 1)

    # Invert the label maps to recover the class names
    idx_to_style = {v: k for k, v in dataset.label_map_style.items()}
    idx_to_artist = {v: k for k, v in dataset.label_map_artist.items()}
    predicted_style = idx_to_style[predicted_style_idx.item()]
    predicted_artist = idx_to_artist[predicted_artist_idx.item()]

    # Enrich prompt
    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )

    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
    output = model_gptneo.generate(
        input_ids=input_ids,
        max_length=250,
        num_return_sequences=1,
        do_sample=True,  # required for temperature/top_p to take effect
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2
    )
    description_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return predicted_style, predicted_artist, description_text
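
# Hypothetical direct call, bypassing Gradio (the path is an example):
# style, artist, text = generate_description("example_artwork.jpg")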
# Gradio interface
def predict(image):
    style, artist, description = generate_description(image)
    return f"**Predicted Style**: {style}\n\n**Predicted Artist**: {artist}\n\n**Description**:\n{description}"
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Markdown(),  # render the **bold** markup returned by predict
    title="AI-Powered Artwork Recognition and Description",
    description="Upload an image of an artwork to predict its style and artist, and generate a description."
)
if __name__ == "__main__":
    iface.launch()