rahul7star's picture
boilerplate
fcc02a2 verified
import os
import requests
import tqdm
from typing import List, Optional, TYPE_CHECKING
def img_root_path(img_id: str):
return os.path.dirname(os.path.dirname(img_id))
if TYPE_CHECKING:
from .dataset_tools_config_modules import DatasetSyncCollectionConfig
img_exts = ['.jpg', '.jpeg', '.webp', '.png']
class Photo:
def __init__(
self,
id,
host,
width,
height,
url,
filename
):
self.id = str(id)
self.host = host
self.width = width
self.height = height
self.url = url
self.filename = filename
def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int):
if img_width > img_height:
scale = min_height / img_height
else:
scale = min_width / img_width
new_width = int(img_width * scale)
new_height = int(img_height * scale)
return new_width, new_height
def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
all_images = []
next_page = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos"
while True:
response = requests.get(next_page, headers={
"Authorization": f"{config.api_key}"
})
response.raise_for_status()
data = response.json()
all_images.extend(data['media'])
if 'next_page' in data and data['next_page']:
next_page = data['next_page']
else:
break
photos = []
for image in all_images:
new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
url = f"{image['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}"
filename = os.path.basename(image['src']['original'])
photos.append(Photo(
id=image['id'],
host="pexels",
width=image['width'],
height=image['height'],
url=url,
filename=filename
))
return photos
def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
headers = {
# "Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"
"Authorization": f"Client-ID {config.api_key}"
}
# headers['Authorization'] = f"Bearer {token}"
url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30"
response = requests.get(url, headers=headers)
response.raise_for_status()
res_headers = response.headers
# parse the link header to get the next page
# 'Link': '<https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=82>; rel="last", <https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=2>; rel="next"'
has_next_page = False
if 'Link' in res_headers:
has_next_page = True
link_header = res_headers['Link']
link_header = link_header.split(',')
link_header = [link.strip() for link in link_header]
link_header = [link.split(';') for link in link_header]
link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header]
link_header = {link[1]: link[0] for link in link_header}
# get page number from last url
last_page = link_header['rel="last']
last_page = last_page.split('?')[1]
last_page = last_page.split('&')
last_page = [param.split('=') for param in last_page]
last_page = {param[0]: param[1] for param in last_page}
last_page = int(last_page['page'])
all_images = response.json()
if has_next_page:
# assume we start on page 1, so we don't need to get it again
for page in tqdm.tqdm(range(2, last_page + 1)):
url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30"
response = requests.get(url, headers=headers)
response.raise_for_status()
all_images.extend(response.json())
photos = []
for image in all_images:
new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
url = f"{image['urls']['raw']}&w={new_width}"
filename = f"{image['id']}.jpg"
photos.append(Photo(
id=image['id'],
host="unsplash",
width=image['width'],
height=image['height'],
url=url,
filename=filename
))
return photos
def get_img_paths(dir_path: str):
os.makedirs(dir_path, exist_ok=True)
local_files = os.listdir(dir_path)
# remove non image files
local_files = [file for file in local_files if os.path.splitext(file)[1].lower() in img_exts]
# make full path
local_files = [os.path.join(dir_path, file) for file in local_files]
return local_files
def get_local_image_ids(dir_path: str):
os.makedirs(dir_path, exist_ok=True)
local_files = get_img_paths(dir_path)
# assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
return set([os.path.basename(file).split('.')[0] for file in local_files])
def get_local_image_file_names(dir_path: str):
os.makedirs(dir_path, exist_ok=True)
local_files = get_img_paths(dir_path)
# assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
return set([os.path.basename(file) for file in local_files])
def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024):
img_width = photo.width
img_height = photo.height
if img_width < min_width or img_height < min_height:
raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}")
img_response = requests.get(photo.url)
img_response.raise_for_status()
os.makedirs(dir_path, exist_ok=True)
filename = os.path.join(dir_path, photo.filename)
with open(filename, 'wb') as file:
file.write(img_response.content)
def update_caption(img_path: str):
# if the caption is a txt file, convert it to a json file
filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
# see if it exists
if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json")):
# todo add poi and what not
return # we have a json file
caption = ""
# see if txt file exists
if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")):
# read it
with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"), 'r') as file:
caption = file.read()
# write json file
with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json"), 'w') as file:
file.write(f'{{"caption": "{caption}"}}')
# delete txt file
os.remove(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"))
# def equalize_img(img_path: str):
# input_path = img_path
# output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path))
# os.makedirs(os.path.dirname(output_path), exist_ok=True)
# process_img(
# img_path=input_path,
# output_path=output_path,
# equalize=True,
# max_size=2056,
# white_balance=False,
# gamma_correction=False,
# strength=0.6,
# )
# def annotate_depth(img_path: str):
# # make fake args
# args = argparse.Namespace()
# args.annotator = "midas"
# args.res = 1024
#
# img = cv2.imread(img_path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#
# output = annotate(img, args)
#
# output = output.astype('uint8')
# output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
#
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
# output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path))
#
# cv2.imwrite(output_path, output)
# def invert_depth(img_path: str):
# img = cv2.imread(img_path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# # invert the colors
# img = cv2.bitwise_not(img)
#
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
# output_path = os.path.join(img_root_path(img_path), INVERTED_DEPTH_DIR, os.path.basename(img_path))
# cv2.imwrite(output_path, img)
#
# # update our list of raw images
# raw_images = get_img_paths(raw_dir)
#
# # update raw captions
# for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"):
# update_caption(image_id)
#
# # equalize images
# for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"):
# if img_path not in eq_images:
# equalize_img(img_path)
#
# # update our list of eq images
# eq_images = get_img_paths(eq_dir)
# # update eq captions
# for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"):
# update_caption(image_id)
#
# # annotate depth
# depth_dir = os.path.join(root_dir, DEPTH_DIR)
# depth_images = get_img_paths(depth_dir)
# for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"):
# if img_path not in depth_images:
# annotate_depth(img_path)
#
# depth_images = get_img_paths(depth_dir)
#
# # invert depth
# inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR)
# inv_depth_images = get_img_paths(inv_depth_dir)
# for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"):
# if img_path not in inv_depth_images:
# invert_depth(img_path)