Spaces:
Sleeping
Sleeping
File size: 7,241 Bytes
ae0af75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import torch
import torch.nn as nn
import cv2
import imageio
import os
import subprocess
from config.core import config
from models.generator import Generator
from torchvision.utils import save_image
def save_some_examples(generator_model, batch, epoch, folder_path=config.PATH_OUTPUT, num_images=15):
    """
    Save a grid of generator outputs and a grid of the matching inputs as PNG files.

    Parameters:
        generator_model (nn.Module): The generator model.
        batch (tuple): A two-element (input, target) batch. Only the first
            ``num_images`` inputs are used; the targets are not written.
        epoch (int): The current epoch; used in the output file names.
        folder_path (str): The folder to save the examples to. Created if missing.
            Defaults to config.PATH_OUTPUT.
        num_images (int): The maximum number of images to save. Defaults to 15.

    Side effects:
        Writes ``y_gen_{epoch}.png`` and ``input_{epoch}.png`` into ``folder_path``.
        The model is switched to eval mode for the forward pass and is always
        switched to train mode before returning, even if inference or saving fails.
    """
    # Ensure the folder exists
    os.makedirs(folder_path, exist_ok=True)

    inputs, _targets = batch  # Unpack the batch; targets are not needed here
    inputs = inputs[:num_images]  # Limit the number of images to num_images

    generator_model.eval()
    try:
        with torch.inference_mode():
            generated = generator_model(inputs)
            # Map from [-1, 1] to [0, 1]; assumes a tanh-normalized generator output
            generated = generated * 0.5 + 0.5
            # nrow=5 gives a 3x5 grid when 15 images are saved
            save_image(generated, os.path.join(folder_path, f"y_gen_{epoch}.png"), nrow=5)
            # The inputs are assumed to use the same [-1, 1] normalization
            save_image(inputs * 0.5 + 0.5, os.path.join(folder_path, f"input_{epoch}.png"), nrow=5)
    finally:
        # Never leave the model stuck in eval mode if something above raised
        generator_model.train()
def update_version_kaggle_dataset():
    """
    Push a new version of the Kaggle dataset using the ``kaggle`` command-line tool.

    Steps, in order:
        1. Run ``kaggle datasets init`` (creates the metadata json).
        2. Overwrite ``/kaggle/working/dataset-metadata.json`` with this project's
           title, dataset id and license.
        3. Run ``kaggle datasets version`` with the message "Updated Dataset".

    Raises:
        subprocess.CalledProcessError: If either ``kaggle`` command exits with a
            non-zero status (both are run with ``check=True``).
    """
    metadata_path = '/kaggle/working/dataset-metadata.json'
    metadata_text = f'{{\n "title": "Update Logs Pix2Pix",\n "id": "muhammadnaufal/pix2pix",\n "licenses": [{{"name": "CC0-1.0"}}]}}'
    init_command = ['kaggle', 'datasets', 'init']
    version_command = ['kaggle', 'datasets', 'version', '-m', 'Updated Dataset', '--quiet', '--dir-mode', 'tar']

    # Make the metadata json
    subprocess.run(init_command, check=True)

    # Replace its contents with the real metadata
    with open(metadata_path, 'w') as metadata_file:
        metadata_file.write(metadata_text)

    # Push the new version
    subprocess.run(version_command, check=True)
def init_generator_model():
    """
    Build a new Generator from the project configuration.

    Returns:
        Generator: A Generator constructed with ``config.IMAGE_CHANNELS`` as
        ``in_channels`` and ``config.FEATURE_GENERATOR`` as ``features``.
    """
    generator_kwargs = {
        "in_channels": config.IMAGE_CHANNELS,
        "features": config.FEATURE_GENERATOR,
    }
    return Generator(**generator_kwargs)
def load_model_weights(checkpoint_path, model, device, prefix):
    """
    Load the weights stored under a key prefix in a checkpoint into a model.

    The checkpoint must contain a ``"state_dict"`` mapping (as PyTorch Lightning
    checkpoints do). Only entries whose key starts with ``"<prefix>."`` are used,
    and only that leading prefix is stripped, so nested modules that happen to
    share the prefix name keep their full path.

    Parameters:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model instance to load weights into.
        device: Passed to ``torch.load`` as ``map_location``.
        prefix (str): The prefix in the checkpoint's state_dict keys to filter by and remove.

    Returns:
        model (torch.nn.Module): The same model instance, with loaded weights.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: From ``load_state_dict`` when the filtered keys do not match
            the model's keys.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Strip only the leading "<prefix>." from matching keys; str.replace would
    # also delete later occurrences inside the key and corrupt nested names.
    marker = f"{prefix}."
    model_weights = {
        key[len(marker):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(marker)
    }

    # Load the weights into the model
    model.load_state_dict(model_weights)
    return model
def initialize_weights(model):
    """
    Initialize convolution and normalization weights from a normal distribution.

    Every ``Conv2d``, ``ConvTranspose2d``, ``BatchNorm2d`` and ``InstanceNorm2d``
    submodule gets its ``weight`` re-drawn from N(mean=0.0, std=0.02). Layers
    without a learnable weight (for example ``InstanceNorm2d`` with its default
    ``affine=False``) are skipped. Biases are left untouched.

    Args:
        model: The model to be initialized (modified in place).

    Returns:
        None
    """
    target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.InstanceNorm2d)
    for module in model.modules():
        if not isinstance(module, target_layers):
            continue
        # Non-affine norm layers expose weight=None; there is nothing to initialize
        if module.weight is None:
            continue
        nn.init.normal_(module.weight, 0.0, 0.02)
def create_video(image_folder, video_name, fps, appearance_duration=None):
    """
    Creates a video from a sequence of images with customizable appearance duration.

    Only ``.png`` files are used. They are ordered by the integer found between
    the first ``-`` and the following ``.`` in the file name (e.g. ``step-12.png``).
    The frame size is taken from the first image.

    Args:
        image_folder (str): The path to the folder containing the images.
        video_name (str): The name of the output video file.
        fps (int): The frames per second of the video.
        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
            Each image is written ``appearance_duration * fps // 1000`` times, but never
            fewer than once. If None, each image is written exactly once.

    Raises:
        ValueError: If the folder contains no ``.png`` files, if an image cannot be
            read, or if a file name's step part is not an integer.
        IndexError: If a ``.png`` file name contains no ``-``.

    Example:
        image_folder = '/path/to/image/folder'
        video_name = 'output_video.mp4'
        fps = 12
        appearance_duration = 200  # Appearance duration of 200ms for each image
        create_video(image_folder, video_name, fps, appearance_duration)
    """
    # Get a list of all image files in the folder
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
    if not image_files:
        raise ValueError(f"No .png images found in {image_folder!r}")

    # Sort the image files based on the step number
    image_files.sort(key=lambda name: int(name.split('-')[1].split('.')[0]))

    # Number of times each image is written; always at least one frame
    if appearance_duration is None:
        frames_per_image = 1
    else:
        frames_per_image = max(1, int(appearance_duration * fps // 1000))

    # Load the first image to get the video size
    first_path = os.path.join(image_folder, image_files[0])
    first_image = cv2.imread(first_path)
    if first_image is None:
        raise ValueError(f"Could not read image: {first_path!r}")
    height, width = first_image.shape[:2]

    # Create a VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Specify the video codec
    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
    try:
        for image_file in image_files:
            image_path = os.path.join(image_folder, image_file)
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Could not read image: {image_path!r}")
            for _ in range(frames_per_image):
                video.write(image)
    finally:
        # Release the video writer even if reading or writing failed
        video.release()
def create_gif(image_folder, gif_name, fps, appearance_duration=None):
    """
    Creates a GIF from a sequence of images sorted by step number, with customizable appearance duration.

    Only ``.png`` files are used. They are ordered by the integer found between
    the first ``-`` and the following ``.`` in the file name (e.g. ``step-12.png``).

    Args:
        image_folder (str): The path to the folder containing the images.
        gif_name (str): The name of the output GIF file.
        fps (int): The frames per second of the GIF.
        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
            Each image is repeated ``appearance_duration * fps // 1000`` times, but never
            fewer than once. If None, each image appears exactly once.

    Raises:
        ValueError: If the folder contains no ``.png`` files, or if a file name's
            step part is not an integer.
        IndexError: If a ``.png`` file name contains no ``-``.

    Example:
        image_folder = '/path/to/image/folder'
        gif_name = 'output_animation.gif'
        fps = 12
        appearance_duration = 300  # Appearance duration of 300ms for each image
        create_gif(image_folder, gif_name, fps, appearance_duration)
    """
    # Get a list of all image files in the folder
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
    if not image_files:
        raise ValueError(f"No .png images found in {image_folder!r}")

    # Sort the image files based on the step number
    image_files.sort(key=lambda name: int(name.split('-')[1].split('.')[0]))

    # Load the images in step order
    frames = [imageio.imread(os.path.join(image_folder, name)) for name in image_files]

    # Repeat each image for the desired duration. Clamp to one repeat so a short
    # duration (or low fps) cannot round down to zero and drop every frame.
    if appearance_duration is not None:
        repeats = max(1, int(appearance_duration * fps // 1000))
        frames = [frame for frame in frames for _ in range(repeats)]

    # Save the frames as a GIF
    imageio.mimsave(gif_name, frames, fps=fps)
|