Delete dataset.py
dataset.py +0 -92
dataset.py
DELETED
@@ -1,92 +0,0 @@
-import os
-import pandas as pd
-import numpy as np
-from PIL import Image
-import torch
-from torch.utils.data import Dataset, DataLoader
-import json
-import random
-
-def image_resize(img, max_size=512):
-    w, h = img.size
-    if w >= h:
-        new_w = max_size
-        new_h = int((max_size / w) * h)
-    else:
-        new_h = max_size
-        new_w = int((max_size / h) * w)
-    return img.resize((new_w, new_h))
-
-def c_crop(image):
-    width, height = image.size
-    new_size = min(width, height)
-    left = (width - new_size) / 2
-    top = (height - new_size) / 2
-    right = (width + new_size) / 2
-    bottom = (height + new_size) / 2
-    return image.crop((left, top, right, bottom))
-
-def crop_to_aspect_ratio(image, ratio="16:9"):
-    width, height = image.size
-    ratio_map = {
-        "16:9": (16, 9),
-        "4:3": (4, 3),
-        "1:1": (1, 1)
-    }
-    target_w, target_h = ratio_map[ratio]
-    target_ratio_value = target_w / target_h
-
-    current_ratio = width / height
-
-    if current_ratio > target_ratio_value:
-        new_width = int(height * target_ratio_value)
-        offset = (width - new_width) // 2
-        crop_box = (offset, 0, offset + new_width, height)
-    else:
-        new_height = int(width / target_ratio_value)
-        offset = (height - new_height) // 2
-        crop_box = (0, offset, width, offset + new_height)
-
-    cropped_img = image.crop(crop_box)
-    return cropped_img
-
-
-class CustomImageDataset(Dataset):
-    def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False):
-        self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
-        self.images.sort()
-        self.img_size = img_size
-        self.caption_type = caption_type
-        self.random_ratio = random_ratio
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, idx):
-        try:
-            img = Image.open(self.images[idx]).convert('RGB')
-            if self.random_ratio:
-                ratio = random.choice(["16:9", "default", "1:1", "4:3"])
-                if ratio != "default":
-                    img = crop_to_aspect_ratio(img, ratio)
-            img = image_resize(img, self.img_size)
-            w, h = img.size
-            new_w = (w // 32) * 32
-            new_h = (h // 32) * 32
-            img = img.resize((new_w, new_h))
-            img = torch.from_numpy((np.array(img) / 127.5) - 1)
-            img = img.permute(2, 0, 1)
-            json_path = self.images[idx].split('.')[0] + '.' + self.caption_type
-            if self.caption_type == "json":
-                prompt = json.load(open(json_path))['caption']
-            else:
-                prompt = open(json_path).read()
-            return img, prompt
-        except Exception as e:
-            print(e)
-            return self.__getitem__(random.randint(0, len(self.images) - 1))
-
-
-def loader(train_batch_size, num_workers, **args):
-    dataset = CustomImageDataset(**args)
-    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
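For context, the removed module exposed a single entry point, loader, which wrapped CustomImageDataset in a shuffled DataLoader. A minimal usage sketch follows; the directory path, batch size, and worker count are illustrative assumptions, not values taken from this repository.

# Illustrative only: shows how the removed loader() was typically invoked.
# All keyword values below are assumptions; loader() forwards everything
# except train_batch_size and num_workers to CustomImageDataset.
from dataset import loader  # the module deleted in this commit

train_dataloader = loader(
    train_batch_size=4,     # assumed batch size
    num_workers=2,          # assumed worker count
    img_dir="data/images",  # assumed folder of .jpg/.png files with matching .json captions
    img_size=512,
    caption_type="json",
    random_ratio=False,
)

for img, prompt in train_dataloader:
    # img: float tensor scaled to [-1, 1], channels-first; the default collate
    # only works when every image in a batch resizes to the same (H, W).
    # prompt: tuple of caption strings read from the .json sidecar files.
    break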