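# Example data preprocessing script: download a filtered subset of COYO-700M,
# generate Canny-edge conditioning images, and record
# (image, conditioning_image, caption) triples in meta.jsonl — the layout
# commonly used for ControlNet-style training data.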
import argparse
import logging
import os
import random

import cv2
import jsonlines
import numpy as np
import requests
from datasets import load_dataset
from PIL import Image

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Example of a data preprocessing script."
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        required=True,
        help="The directory to store the dataset.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        required=True,
        help="The directory to store the `datasets` download cache.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        required=True,
        help="Target number of examples to keep after filtering and downloading.",
    )
    parser.add_argument(
        "--num_proc",
        type=int,
        default=1,
        help="Number of processes to use in `dataset.map()`.",
    )

    args = parser.parse_args()
    return args


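# Quality filter: keep an example only if its image-text CLIP (ViT-B/32)
# similarity is high enough, its watermark score low enough, and its LAION v2
# aesthetic score high enough. The cutoffs below are heuristics.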
def filter_function(example):
    if example["clip_similarity_vitb32"] < 0.3:
        return False
    if example["watermark_score"] > 0.4:
        return False
    if example["aesthetic_score_laion_v2"] < 6.0:
        return False
    return True


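# Note the order of operations: the first `max_train_samples` rows are
# selected first and only then filtered, so the returned dataset is generally
# smaller than `max_train_samples`.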
def filter_dataset(dataset, max_train_samples):
    small_dataset = dataset.select(range(max_train_samples)).filter(filter_function)
    return small_dataset


if __name__ == "__main__":
    args = parse_args()

    dataset = load_dataset(
        "kakaobrain/coyo-700m",
        cache_dir=args.cache_dir,
        split="train",
    )

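    # Estimate what fraction of examples survives `filter_function` by
    # filtering the first 20,000 rows.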
    filter_ratio = len(filter_dataset(dataset, 20000)) / 20000

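    # Oversample so that, after filtering and download failures, roughly
    # `args.max_train_samples` examples remain; the 0.8 divisor adds a ~20%
    # safety margin, presumably for failed downloads.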
    max_train_samples = int(args.max_train_samples / filter_ratio / 0.8)

    small_dataset = filter_dataset(dataset, max_train_samples)

    # Create the output directories up front so the `save()` calls below do
    # not fail on a missing path.
    os.makedirs(f"{args.train_data_dir}/images", exist_ok=True)
    os.makedirs(f"{args.train_data_dir}/processed_images", exist_ok=True)

    def preprocess_and_save(example):
        image_url = example["url"]
        try:
            # Download the image; `stream=True` plus `.raw` lets PIL read the
            # response body directly, and the timeout bounds slow hosts.
            image = Image.open(requests.get(image_url, stream=True, timeout=5).raw)
            # Normalize the mode so Canny and PNG saving behave consistently.
            image = image.convert("RGB")
            image_path = f"{args.train_data_dir}/images/{example['id']}.png"
            image.save(image_path)

            # Build the conditioning image: run Canny edge detection with
            # randomly sampled thresholds, then stack the single-channel edge
            # map into a three-channel image.
            processed_image = np.array(image)
            thresholds = (
                random.randint(0, 255),
                random.randint(0, 255),
            )
            processed_image = cv2.Canny(processed_image, min(thresholds), max(thresholds))
            processed_image = processed_image[:, :, None]
            processed_image = np.concatenate(
                [processed_image, processed_image, processed_image], axis=2
            )
            processed_image = Image.fromarray(processed_image)
            processed_image_path = (
                f"{args.train_data_dir}/processed_images/{example['id']}.png"
            )
            processed_image.save(processed_image_path)

            # Append one metadata record per successfully processed example.
            meta = {
                "image": image_path,
                "conditioning_image": processed_image_path,
                "caption": example["text"],
            }
            with jsonlines.open(f"{args.train_data_dir}/meta.jsonl", "a") as writer:
                writer.write(meta)

        except Exception as e:
            logger.error(f"Failed to process image {image_url}: {str(e)}")

    # Run the per-example download/processing step. With `num_proc` > 1,
    # several worker processes append to meta.jsonl concurrently.
    small_dataset.map(preprocess_and_save, num_proc=args.num_proc)

    print(f"created data folder at: {args.train_data_dir}")