Junjie96 commited on
Commit
4ca31c3
1 Parent(s): fd4b3ba

Update src/util.py

Browse files
Files changed (1) hide show
  1. src/util.py +16 -2
src/util.py CHANGED
@@ -1,6 +1,7 @@
1
  import concurrent.futures
2
  import io
3
  import os
 
4
 
5
  import numpy as np
6
  import oss2
@@ -20,6 +21,15 @@ oss_path = "hejunjie.hjj/TransferAnythingHF"
20
  oss_path_img_gallery = "hejunjie.hjj/TransferAnythingHF_img_gallery"
21
 
22
 
 
 
 
 
 
 
 
 
 
23
  def download_img_pil(index, img_url):
24
  # print(img_url)
25
  r = requests.get(img_url, stream=True)
@@ -50,11 +60,13 @@ def download_images(img_urls, batch_size):
50
 
51
  def upload_np_2_oss(input_image, name="cache.png", gallery=False):
52
  assert name.lower().endswith((".png", ".jpg")), name
 
 
53
  imgByteArr = io.BytesIO()
54
  if name.lower().endswith(".png"):
55
- Image.fromarray(input_image).save(imgByteArr, format="PNG")
56
  else:
57
- Image.fromarray(input_image).save(imgByteArr, format="JPEG", quality=95)
58
  imgByteArr = imgByteArr.getvalue()
59
 
60
  if gallery:
@@ -62,8 +74,10 @@ def upload_np_2_oss(input_image, name="cache.png", gallery=False):
62
  else:
63
  path = oss_path
64
 
 
65
  bucket.put_object(path + "/" + name, imgByteArr) # data为数据,可以是图片
66
  ret = bucket.sign_url('GET', path + "/" + name, 60 * 60 * 24) # 返回值为链接,参数依次为,方法/oss上文件路径/过期时间(s)
 
67
  del imgByteArr
68
  return ret
69
 
 
1
  import concurrent.futures
2
  import io
3
  import os
4
+ import time
5
 
6
  import numpy as np
7
  import oss2
 
21
  oss_path_img_gallery = "hejunjie.hjj/TransferAnythingHF_img_gallery"
22
 
23
 
24
+ def resize(image, short_side_length=768):
25
+ width, height = image.size
26
+ ratio = short_side_length / min(width, height)
27
+ new_width = int(width * ratio)
28
+ new_height = int(height * ratio)
29
+ resized_image = image.resize((new_width, new_height))
30
+ return resized_image
31
+
32
+
33
  def download_img_pil(index, img_url):
34
  # print(img_url)
35
  r = requests.get(img_url, stream=True)
 
60
 
61
  def upload_np_2_oss(input_image, name="cache.png", gallery=False):
62
  assert name.lower().endswith((".png", ".jpg")), name
63
+ if name.lower().endswith(".png"):
64
+ name = name[:-4] + ".jpg"
65
  imgByteArr = io.BytesIO()
66
  if name.lower().endswith(".png"):
67
+ resize(Image.fromarray(input_image)).save(imgByteArr, format="PNG")
68
  else:
69
+ resize(Image.fromarray(input_image)).save(imgByteArr, format="JPEG", quality=95)
70
  imgByteArr = imgByteArr.getvalue()
71
 
72
  if gallery:
 
74
  else:
75
  path = oss_path
76
 
77
+ start_time = time.perf_counter()
78
  bucket.put_object(path + "/" + name, imgByteArr) # data为数据,可以是图片
79
  ret = bucket.sign_url('GET', path + "/" + name, 60 * 60 * 24) # 返回值为链接,参数依次为,方法/oss上文件路径/过期时间(s)
80
+ logger.info(f"upload cost: {time.perf_counter() - start_time} s.")
81
  del imgByteArr
82
  return ret
83