MusIre commited on
Commit
67bbe81
·
verified ·
1 Parent(s): bf1dc6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -6
app.py CHANGED
@@ -10,6 +10,7 @@ import torch.nn as nn
10
  from sklearn.metrics import classification_report
11
  from torch.optim.lr_scheduler import ReduceLROnPlateau
12
  import gradio as gr
 
13
 
14
  # Device setup
15
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
@@ -25,7 +26,7 @@ data_transforms = transforms.Compose([
25
  # Load datasets for enriched prompts
26
  dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
27
  dataset_desc.columns = dataset_desc.columns.str.lower()
28
- style_desc = pd.read_csv("style_desc.csv", delimiter=';') # CSV containing style-specific descriptions
29
  style_desc.columns = style_desc.columns.str.lower()
30
 
31
  # Function to enrich prompts with custom data
@@ -91,8 +92,7 @@ model_name = "EleutherAI/gpt-neo-1.3B"
91
  tokenizer = AutoTokenizer.from_pretrained(model_name)
92
  model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
93
 
94
- def generate_description(image_path):
95
- image = Image.open(image_path).convert("RGB")
96
  image_resnet = data_transforms(image).unsqueeze(0).to(device)
97
 
98
  model_resnet.eval()
@@ -112,13 +112,18 @@ def generate_description(image_path):
112
  "Describe its distinctive features, considering both the artist's techniques and the artistic style."
113
  )
114
 
115
- input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
 
 
116
  output = model_gptneo.generate(
117
  input_ids=input_ids,
 
118
  max_length=300,
119
  temperature=0.7,
120
  top_p=0.9,
121
- repetition_penalty=1.2
 
 
122
  )
123
 
124
  description_text = tokenizer.decode(output[0], skip_special_tokens=True)
@@ -127,12 +132,20 @@ def generate_description(image_path):
127
 
128
  # Gradio interface
129
  def gradio_interface(image):
 
 
 
 
 
 
 
 
130
  predicted_style, predicted_artist, description = generate_description(image)
131
  return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"
132
 
133
  iface = gr.Interface(
134
  fn=gradio_interface,
135
- inputs=gr.Image(type="filepath"),
136
  outputs="text",
137
  title="AI Artwork Analysis",
138
  description="Upload an image to predict its artistic style and creator, and generate a detailed description."
 
10
  from sklearn.metrics import classification_report
11
  from torch.optim.lr_scheduler import ReduceLROnPlateau
12
  import gradio as gr
13
+ from io import BytesIO
14
 
15
  # Device setup
16
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 
26
  # Load datasets for enriched prompts
27
  dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
28
  dataset_desc.columns = dataset_desc.columns.str.lower()
29
+ style_desc = pd.read_csv("style_desc.csv", delimiter=';')
30
  style_desc.columns = style_desc.columns.str.lower()
31
 
32
  # Function to enrich prompts with custom data
 
92
  tokenizer = AutoTokenizer.from_pretrained(model_name)
93
  model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
94
 
95
+ def generate_description(image):
 
96
  image_resnet = data_transforms(image).unsqueeze(0).to(device)
97
 
98
  model_resnet.eval()
 
112
  "Describe its distinctive features, considering both the artist's techniques and the artistic style."
113
  )
114
 
115
+ input_ids = tokenizer.encode(full_prompt, return_tensors="pt", padding=True).to(device)
116
+ attention_mask = input_ids != tokenizer.pad_token_id
117
+
118
  output = model_gptneo.generate(
119
  input_ids=input_ids,
120
+ attention_mask=attention_mask,
121
  max_length=300,
122
  temperature=0.7,
123
  top_p=0.9,
124
+ repetition_penalty=1.2,
125
+ do_sample=True,
126
+ pad_token_id=tokenizer.eos_token_id
127
  )
128
 
129
  description_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
132
 
133
  # Gradio interface
134
  def gradio_interface(image):
135
+ if image is None:
136
+ return "No image provided. Please upload an image."
137
+
138
+ if isinstance(image, BytesIO):
139
+ image = Image.open(image).convert("RGB")
140
+ else:
141
+ image = Image.open(image).convert("RGB")
142
+
143
  predicted_style, predicted_artist, description = generate_description(image)
144
  return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"
145
 
146
  iface = gr.Interface(
147
  fn=gradio_interface,
148
+ inputs=gr.Image(type="file"),
149
  outputs="text",
150
  title="AI Artwork Analysis",
151
  description="Upload an image to predict its artistic style and creator, and generate a detailed description."