Update app.py
Browse files
app.py
CHANGED
@@ -60,7 +60,7 @@ unet = pipeline.unet
|
|
60 |
def adjust_unet_input_layer(params):
|
61 |
conv_in_weight = params['unet']['conv_in']['kernel']
|
62 |
new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
|
63 |
-
new_conv_in_weight = new_conv_in_weight.at[:, :, :
|
64 |
params['unet']['conv_in']['kernel'] = new_conv_in_weight
|
65 |
return params
|
66 |
|
@@ -70,14 +70,17 @@ params = adjust_unet_input_layer(params)
|
|
70 |
def preprocess_images(examples):
|
71 |
def process_image(image):
|
72 |
if isinstance(image, str):
|
|
|
|
|
73 |
image = Image.open(image)
|
74 |
if not isinstance(image, Image.Image):
|
75 |
-
|
76 |
image = image.convert("RGB").resize((512, 512))
|
77 |
image = np.array(image).astype(np.float32) / 255.0
|
78 |
return image.transpose(2, 0, 1)
|
79 |
|
80 |
-
|
|
|
81 |
|
82 |
# Load dataset from Hugging Face
|
83 |
dataset_name = "uruguayai/montevideo"
|
@@ -94,6 +97,7 @@ else:
|
|
94 |
print("Processing dataset...")
|
95 |
dataset = load_dataset(dataset_name)
|
96 |
processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
|
|
|
97 |
with open(dataset_cache_file, 'wb') as f:
|
98 |
pickle.dump(processed_dataset, f)
|
99 |
|
|
|
def adjust_unet_input_layer(params):
    """Rebuild the UNet ``conv_in`` kernel as a (3, 3, 4, 320) float32 array.

    The first three input channels are copied from the existing kernel;
    the fourth input channel is left zeroed.  The kernel is replaced
    inside ``params`` (the tree is mutated in place) and the same
    ``params`` object is returned for convenience.
    """
    original_kernel = params['unet']['conv_in']['kernel']
    # Fresh all-zero kernel in HWIO layout, then copy input channels 0-2 over.
    rebuilt = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
    rebuilt = rebuilt.at[:, :, :3, :].set(original_kernel[:, :, :3, :])
    params['unet']['conv_in']['kernel'] = rebuilt
    return params
def preprocess_images(examples):
    """Convert a batch of images (file paths or PIL images) to CHW float arrays.

    Only ``.jpg``/``.jpeg`` paths are accepted; anything that cannot be
    resolved to a PIL image is dropped from the batch.

    Args:
        examples: batch dict with an ``"image"`` key holding paths or
            PIL ``Image`` objects.

    Returns:
        ``{"pixel_values": [...]}`` where each entry is a (3, 512, 512)
        float32 array scaled to [0, 1].
    """
    def process_image(image):
        if isinstance(image, str):
            # str.endswith accepts a tuple of suffixes — one call replaces
            # the original chained `not ... and not ...` checks.
            if not image.lower().endswith(('.jpg', '.jpeg')):
                return None
            image = Image.open(image)
        if not isinstance(image, Image.Image):
            return None
        image = image.convert("RGB").resize((512, 512))
        image = np.array(image).astype(np.float32) / 255.0
        # HWC -> CHW, the layout expected by the diffusion UNet.
        return image.transpose(2, 0, 1)

    processed = [process_image(img) for img in examples["image"]]
    # Drop entries that failed validation so the batch stays clean.
    return {"pixel_values": [img for img in processed if img is not None]}
# Load dataset from Hugging Face
# Hub dataset id ("user/name") consumed by load_dataset() further down.
dataset_name = "uruguayai/montevideo"
|
|
# Cache miss: download the dataset, preprocess every image, and pickle the
# result.  NOTE(review): per the diff hunk header this runs inside an `else:`
# branch of a cache check defined above this view; `load_dataset` and
# `dataset_cache_file` are also defined earlier in the file — confirm.
print("Processing dataset...")
dataset = load_dataset(dataset_name)
# Batched map: preprocess_images returns {"pixel_values": [...]} per batch,
# and the original columns are dropped.
processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
# Intended to drop examples whose image failed preprocessing.
# NOTE(review): after a batched map each example's 'pixel_values' is a single
# (3, 512, 512) array, so len(...) is 3 and this filter may be a no-op —
# the dropping already happens inside preprocess_images; verify.
processed_dataset = processed_dataset.filter(lambda example: len(example['pixel_values']) > 0)
with open(dataset_cache_file, 'wb') as f:
    pickle.dump(processed_dataset, f)