DigiP-AI committed on
Commit 63a9d79 · verified · 1 Parent(s): d5398d8

Delete dataset.py

Files changed (1)
  1. dataset.py +0 -92
dataset.py DELETED
@@ -1,92 +0,0 @@
-import os
-import pandas as pd
-import numpy as np
-from PIL import Image
-import torch
-from torch.utils.data import Dataset, DataLoader
-import json
-import random
-
-def image_resize(img, max_size=512):
-    w, h = img.size
-    if w >= h:
-        new_w = max_size
-        new_h = int((max_size / w) * h)
-    else:
-        new_h = max_size
-        new_w = int((max_size / h) * w)
-    return img.resize((new_w, new_h))
-
-def c_crop(image):
-    width, height = image.size
-    new_size = min(width, height)
-    left = (width - new_size) / 2
-    top = (height - new_size) / 2
-    right = (width + new_size) / 2
-    bottom = (height + new_size) / 2
-    return image.crop((left, top, right, bottom))
-
-def crop_to_aspect_ratio(image, ratio="16:9"):
-    width, height = image.size
-    ratio_map = {
-        "16:9": (16, 9),
-        "4:3": (4, 3),
-        "1:1": (1, 1)
-    }
-    target_w, target_h = ratio_map[ratio]
-    target_ratio_value = target_w / target_h
-
-    current_ratio = width / height
-
-    if current_ratio > target_ratio_value:
-        new_width = int(height * target_ratio_value)
-        offset = (width - new_width) // 2
-        crop_box = (offset, 0, offset + new_width, height)
-    else:
-        new_height = int(width / target_ratio_value)
-        offset = (height - new_height) // 2
-        crop_box = (0, offset, width, offset + new_height)
-
-    cropped_img = image.crop(crop_box)
-    return cropped_img
-
-
-class CustomImageDataset(Dataset):
-    def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False):
-        self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
-        self.images.sort()
-        self.img_size = img_size
-        self.caption_type = caption_type
-        self.random_ratio = random_ratio
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, idx):
-        try:
-            img = Image.open(self.images[idx]).convert('RGB')
-            if self.random_ratio:
-                ratio = random.choice(["16:9", "default", "1:1", "4:3"])
-                if ratio != "default":
-                    img = crop_to_aspect_ratio(img, ratio)
-            img = image_resize(img, self.img_size)
-            w, h = img.size
-            new_w = (w // 32) * 32
-            new_h = (h // 32) * 32
-            img = img.resize((new_w, new_h))
-            img = torch.from_numpy((np.array(img) / 127.5) - 1)
-            img = img.permute(2, 0, 1)
-            json_path = self.images[idx].split('.')[0] + '.' + self.caption_type
-            if self.caption_type == "json":
-                prompt = json.load(open(json_path))['caption']
-            else:
-                prompt = open(json_path).read()
-            return img, prompt
-        except Exception as e:
-            print(e)
-            return self.__getitem__(random.randint(0, len(self.images) - 1))
-
-
-def loader(train_batch_size, num_workers, **args):
-    dataset = CustomImageDataset(**args)
-    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
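For context: the deleted module's entry point is loader(), which forwards any extra keyword arguments to CustomImageDataset and wraps it in a shuffled DataLoader yielding (image tensor, caption) batches. A minimal usage sketch follows; the directory path and hyperparameter values are hypothetical and not taken from this repository.

    from dataset import loader

    # Hypothetical configuration; the actual training setup is not part of this commit.
    train_dataloader = loader(
        train_batch_size=4,
        num_workers=2,
        img_dir="data/train",   # expects .jpg/.png images with matching .json caption files
        img_size=512,
        caption_type="json",
        random_ratio=False,
    )

    for img, prompt in train_dataloader:
        # img: float tensor scaled to [-1, 1], shape (B, 3, H, W); prompt: tuple of caption strings
        ...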