uruguayai commited on
Commit
7166f76
·
verified ·
1 Parent(s): 967b314

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -60,7 +60,7 @@ unet = pipeline.unet
60
  def adjust_unet_input_layer(params):
61
  conv_in_weight = params['unet']['conv_in']['kernel']
62
  new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
63
- new_conv_in_weight = new_conv_in_weight.at[:, :, :4, :].set(conv_in_weight[:, :, :4, :])
64
  params['unet']['conv_in']['kernel'] = new_conv_in_weight
65
  return params
66
 
@@ -70,14 +70,17 @@ params = adjust_unet_input_layer(params)
70
  def preprocess_images(examples):
71
  def process_image(image):
72
  if isinstance(image, str):
 
 
73
  image = Image.open(image)
74
  if not isinstance(image, Image.Image):
75
- raise ValueError(f"Unexpected image type: {type(image)}")
76
  image = image.convert("RGB").resize((512, 512))
77
  image = np.array(image).astype(np.float32) / 255.0
78
  return image.transpose(2, 0, 1)
79
 
80
- return {"pixel_values": [process_image(img) for img in examples["image"]]}
 
81
 
82
  # Load dataset from Hugging Face
83
  dataset_name = "uruguayai/montevideo"
@@ -94,6 +97,7 @@ else:
94
  print("Processing dataset...")
95
  dataset = load_dataset(dataset_name)
96
  processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
 
97
  with open(dataset_cache_file, 'wb') as f:
98
  pickle.dump(processed_dataset, f)
99
 
 
60
  def adjust_unet_input_layer(params):
61
  conv_in_weight = params['unet']['conv_in']['kernel']
62
  new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
63
+ new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
64
  params['unet']['conv_in']['kernel'] = new_conv_in_weight
65
  return params
66
 
 
70
  def preprocess_images(examples):
71
  def process_image(image):
72
  if isinstance(image, str):
73
+ if not image.lower().endswith('.jpg') and not image.lower().endswith('.jpeg'):
74
+ return None
75
  image = Image.open(image)
76
  if not isinstance(image, Image.Image):
77
+ return None
78
  image = image.convert("RGB").resize((512, 512))
79
  image = np.array(image).astype(np.float32) / 255.0
80
  return image.transpose(2, 0, 1)
81
 
82
+ processed = [process_image(img) for img in examples["image"]]
83
+ return {"pixel_values": [img for img in processed if img is not None]}
84
 
85
  # Load dataset from Hugging Face
86
  dataset_name = "uruguayai/montevideo"
 
97
  print("Processing dataset...")
98
  dataset = load_dataset(dataset_name)
99
  processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
100
+ processed_dataset = processed_dataset.filter(lambda example: len(example['pixel_values']) > 0)
101
  with open(dataset_cache_file, 'wb') as f:
102
  pickle.dump(processed_dataset, f)
103