import os
import random
import copy
from PIL import Image
import numpy as np
import json
from torch.utils.data import Dataset
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
from utils.image_utils import random_augmentation, crop_img
from utils.degradation_utils import Degradation


class DerainDehazeDataset(Dataset):
    """Single-sample dataset wrapping one in-memory image and its text prompt.

    There is no ground-truth image here: the "clean" tensor returned by
    ``__getitem__`` is built from the same input image as the degraded one,
    so both tensors have identical content.
    """

    def __init__(self, args, img, text_prompt, task="derain"):
        """Store the input image and prompt.

        Args:
            args: Option object; stored as ``self.args`` but not read by this
                class.
            img: Input image. Must provide ``.convert('RGB')`` (presumably a
                ``PIL.Image.Image`` -- TODO confirm against callers).
            text_prompt: Prompt returned unchanged with the sample.
            task: Kept for interface compatibility; stored as ``self.task``
                but not otherwise used by this class.
        """
        super().__init__()
        self.args = args
        self.task = task
        self.toTensor = ToTensor()
        self.img = img
        self.text_prompt = text_prompt

    def __getitem__(self, idx):
        """Return the dataset's single sample.

        Args:
            idx: ``0`` or ``-1``, the only valid indices for a length-1
                dataset.

        Returns:
            ``([[""]], "", degraded_img, clean_img, text_prompt)``. Both image
            tensors are ``ToTensor`` applied to the RGB image after
            ``crop_img(..., base=16)``.

        Raises:
            IndexError: If ``idx`` is out of range. Without this, Python's
                ``__getitem__``-based iteration (e.g. ``list(dataset)``) would
                loop forever because this method never signals the end.
        """
        if idx not in (0, -1):
            raise IndexError(f"index {idx} out of range for dataset of length 1")

        # Convert and crop once instead of twice; both outputs derive from the
        # same image. Assumes crop_img is deterministic -- TODO confirm.
        rgb = crop_img(np.array(self.img.convert('RGB')), base=16)
        # One ToTensor call per output, mirroring the original's two separate
        # conversions (presumably yielding independent tensors -- verify if
        # callers mutate one of them in place).
        degraded_img = self.toTensor(rgb)
        clean_img = self.toTensor(rgb)

        degradation = ""
        degraded_name = [""]
        return [degraded_name], degradation, degraded_img, clean_img, self.text_prompt

    def __len__(self):
        """Return 1: the dataset holds exactly one image."""
        return 1