Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import cv2 | |
import imageio | |
import os | |
import subprocess | |
from config.core import config | |
from models.generator import Generator | |
from torchvision.utils import save_image | |
def save_some_examples(generator_model, batch, epoch, folder_path=config.PATH_OUTPUT, num_images=15):
    """
    Save a grid of generator outputs and a grid of the corresponding inputs.

    Two PNG files are written into ``folder_path``:
    ``y_gen_{epoch}.png`` (generator output) and ``input_{epoch}.png``
    (the inputs fed to the generator). Both grids use 5 images per row.

    Parameters:
        generator_model (nn.Module): The generator model.
        batch (tuple): The batch of input and target images as a tuple of tensors.
            Only the input tensor is used; the target is unpacked but not saved.
        epoch (int): The current epoch, used in the output file names.
        folder_path (str): The folder path to save the examples to. Defaults to config.PATH_OUTPUT.
        num_images (int): The maximum number of images to save. Defaults to 15.
    """
    # Ensure the folder exists
    os.makedirs(folder_path, exist_ok=True)

    # Unpack the batch and keep at most num_images inputs
    x, _ = batch
    x = x[:num_images]

    # Remember the current mode so it can be restored afterwards instead of
    # unconditionally forcing the model back into training mode.
    was_training = generator_model.training
    generator_model.eval()
    try:
        with torch.inference_mode():
            y_fake = generator_model(x)
            # Map values from [-1, 1] to [0, 1] (assumes tanh-normalized images)
            y_fake = y_fake * 0.5 + 0.5
            # Generated images, 5 per row
            save_image(y_fake, os.path.join(folder_path, f"y_gen_{epoch}.png"), nrow=5)
            # Input images, 5 per row
            save_image(x * 0.5 + 0.5, os.path.join(folder_path, f"input_{epoch}.png"), nrow=5)
    finally:
        # Restore the previous mode even if inference or saving failed
        generator_model.train(was_training)
def update_version_kaggle_dataset():
    """
    Push a new version of the Kaggle dataset through the ``kaggle`` CLI.

    Steps, in order:
        1. Run ``kaggle datasets init``.
        2. Overwrite ``/kaggle/working/dataset-metadata.json`` with the
           title, id and license of the dataset.
        3. Run ``kaggle datasets version`` with directories packed as tar.

    Raises:
        subprocess.CalledProcessError: If either CLI call exits non-zero.
    """
    metadata_path = '/kaggle/working/dataset-metadata.json'
    init_command = ['kaggle', 'datasets', 'init']
    version_command = [
        'kaggle', 'datasets', 'version',
        '-m', 'Updated Dataset',
        '--quiet',
        '--dir-mode', 'tar',
    ]

    # NOTE(review): neither CLI call passes an explicit path, so this
    # presumably expects the working directory to be /kaggle/working - confirm.
    subprocess.run(init_command, check=True)

    # Replace the metadata file with this dataset's metadata
    with open(metadata_path, 'w') as json_fid:
        json_fid.write(f'{{\n "title": "Update Logs Pix2Pix",\n "id": "muhammadnaufal/pix2pix",\n "licenses": [{{"name": "CC0-1.0"}}]}}')

    # Upload the new version
    subprocess.run(version_command, check=True)
def init_generator_model():
    """
    Build a Generator from the project configuration.

    The generator is constructed with ``config.IMAGE_CHANNELS`` input
    channels and ``config.FEATURE_GENERATOR`` features.

    Returns:
        Generator: A newly constructed Generator instance.
    """
    generator_kwargs = {
        "in_channels": config.IMAGE_CHANNELS,
        "features": config.FEATURE_GENERATOR,
    }
    return Generator(**generator_kwargs)
def load_model_weights(checkpoint_path, model, device, prefix):
    """
    Load specific weights from a PyTorch Lightning checkpoint into a model.

    Only entries of ``checkpoint["state_dict"]`` whose key starts with
    ``"{prefix}."`` are used, and that leading prefix is stripped before
    the weights are loaded into ``model``.

    Parameters:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model instance to load weights into.
        device (str or torch.device): Passed to ``torch.load`` as ``map_location``.
        prefix (str): The prefix in the checkpoint's state_dict keys to filter by and remove.

    Returns:
        model (torch.nn.Module): The same model instance, with loaded weights.
    """
    # Load the checkpoint onto the requested device
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Strip only the LEADING "<prefix>." from each key. Using str.replace here
    # would also delete later occurrences (e.g. "gen.gen.weight" -> "weight")
    # and corrupt the key names.
    key_prefix = f"{prefix}."
    prefix_length = len(key_prefix)
    model_weights = {
        key[prefix_length:]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(key_prefix)
    }

    # Load the weights into the model (strict: mismatched keys raise)
    model.load_state_dict(model_weights)
    return model
def initialize_weights(model):
    """
    Initializes the weights of a model using a normal distribution.

    Every Conv2d, ConvTranspose2d, BatchNorm2d and InstanceNorm2d module in
    ``model`` has its ``weight`` re-drawn in place from N(mean=0.0, std=0.02).
    Biases are left untouched.

    Args:
        model: The model to be initialized.

    Returns:
        None
    """
    target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.InstanceNorm2d)
    for m in model.modules():
        if not isinstance(m, target_layers):
            continue
        # Normalization layers built with affine=False (the InstanceNorm2d
        # default) have weight=None; there is nothing to initialize there.
        if m.weight is None:
            continue
        nn.init.normal_(m.weight.data, 0.0, 0.02)
def create_video(image_folder, video_name, fps, appearance_duration=None):
    """
    Creates a video from a sequence of images with customizable appearance duration.

    The ``.png`` files in ``image_folder`` are ordered by the integer found
    between the first ``-`` and the following ``.`` of each file name
    (e.g. ``step-12.png`` -> 12) and written with the ``mp4v`` codec. The
    frame size is taken from the first image.

    Args:
        image_folder (str): The path to the folder containing the images.
        video_name (str): The name of the output video file.
        fps (int): The frames per second of the video.
        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
            It is converted to ``appearance_duration * fps // 1000`` frames, with a
            minimum of one frame per image. If None, each image is shown for one frame.

    Raises:
        ValueError: If the folder contains no ``.png`` files or an image cannot be read.

    Example:
        image_folder = '/path/to/image/folder'
        video_name = 'output_video.mp4'
        fps = 12
        appearance_duration = 200  # Appearance duration of 200ms for each image
        create_video(image_folder, video_name, fps, appearance_duration)
    """
    # Collect the PNG files and fail early with a clear message if there are none
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
    if not image_files:
        raise ValueError(f"No .png images found in {image_folder!r}")

    # Sort the image files based on the step number
    image_files.sort(key=lambda name: int(name.split('-')[1].split('.')[0]))

    # Number of frames each image occupies. The image is written exactly this
    # many times (never zero), so the on-screen time matches the request
    # instead of being one frame longer.
    if appearance_duration is None:
        frames_per_image = 1
    else:
        frames_per_image = max(1, int(appearance_duration * fps // 1000))

    # Load the first image to get the video size
    first_path = os.path.join(image_folder, image_files[0])
    first_image = cv2.imread(first_path)
    if first_image is None:
        raise ValueError(f"Could not read image {first_path!r}")
    height, width = first_image.shape[:2]

    # Create a VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Specify the video codec
    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
    try:
        for image_file in image_files:
            image_path = os.path.join(image_folder, image_file)
            image = cv2.imread(image_path)
            # cv2.imread signals failure by returning None rather than raising
            if image is None:
                raise ValueError(f"Could not read image {image_path!r}")
            for _ in range(frames_per_image):
                video.write(image)
    finally:
        # Release the video writer even if a frame could not be written
        video.release()
def create_gif(image_folder, gif_name, fps, appearance_duration=None):
    """
    Creates a GIF from a sequence of images sorted by step number, with customizable appearance duration.

    The ``.png`` files in ``image_folder`` are ordered by the integer found
    between the first ``-`` and the following ``.`` of each file name
    (e.g. ``step-12.png`` -> 12).

    Args:
        image_folder (str): The path to the folder containing the images.
        gif_name (str): The name of the output GIF file.
        fps (int): The frames per second of the GIF.
        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
            It is converted to ``appearance_duration * fps // 1000`` repeated frames, with a
            minimum of one frame per image. If None, each image is shown for one frame.

    Raises:
        ValueError: If the folder contains no ``.png`` files.

    Example:
        image_folder = '/path/to/image/folder'
        gif_name = 'output_animation.gif'
        fps = 12
        appearance_duration = 300  # Appearance duration of 300ms for each image
        create_gif(image_folder, gif_name, fps, appearance_duration)
    """
    # Collect the PNG files and fail early with a clear message if there are none
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
    if not image_files:
        raise ValueError(f"No .png images found in {image_folder!r}")

    # Sort the image files based on the step number
    image_files.sort(key=lambda name: int(name.split('-')[1].split('.')[0]))

    # Number of times each image is repeated. Clamped to at least one so a
    # short duration / low fps combination cannot drop every frame.
    if appearance_duration is None:
        repeats = 1
    else:
        repeats = max(1, int(appearance_duration * fps // 1000))

    # Load each image and repeat it for the desired duration
    repeated_images = []
    for image_file in image_files:
        image = imageio.imread(os.path.join(image_folder, image_file))
        repeated_images.extend([image] * repeats)

    # NOTE(review): newer imageio releases may expect `duration` instead of
    # `fps` for GIF output - confirm against the installed version.
    imageio.mimsave(gif_name, repeated_images, fps=fps)