Spaces:
Running
Running
Update src/util.py
Browse files- src/util.py +16 -2
src/util.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import concurrent.futures
|
2 |
import io
|
3 |
import os
|
|
|
4 |
|
5 |
import numpy as np
|
6 |
import oss2
|
@@ -20,6 +21,15 @@ oss_path = "hejunjie.hjj/TransferAnythingHF"
|
|
20 |
oss_path_img_gallery = "hejunjie.hjj/TransferAnythingHF_img_gallery"
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def download_img_pil(index, img_url):
|
24 |
# print(img_url)
|
25 |
r = requests.get(img_url, stream=True)
|
@@ -50,11 +60,13 @@ def download_images(img_urls, batch_size):
|
|
50 |
|
51 |
def upload_np_2_oss(input_image, name="cache.png", gallery=False):
|
52 |
assert name.lower().endswith((".png", ".jpg")), name
|
|
|
|
|
53 |
imgByteArr = io.BytesIO()
|
54 |
if name.lower().endswith(".png"):
|
55 |
-
Image.fromarray(input_image).save(imgByteArr, format="PNG")
|
56 |
else:
|
57 |
-
Image.fromarray(input_image).save(imgByteArr, format="JPEG", quality=95)
|
58 |
imgByteArr = imgByteArr.getvalue()
|
59 |
|
60 |
if gallery:
|
@@ -62,8 +74,10 @@ def upload_np_2_oss(input_image, name="cache.png", gallery=False):
|
|
62 |
else:
|
63 |
path = oss_path
|
64 |
|
|
|
65 |
bucket.put_object(path + "/" + name, imgByteArr) # data为数据,可以是图片
|
66 |
ret = bucket.sign_url('GET', path + "/" + name, 60 * 60 * 24) # 返回值为链接,参数依次为,方法/oss上文件路径/过期时间(s)
|
|
|
67 |
del imgByteArr
|
68 |
return ret
|
69 |
|
|
|
1 |
import concurrent.futures
|
2 |
import io
|
3 |
import os
|
4 |
+
import time
|
5 |
|
6 |
import numpy as np
|
7 |
import oss2
|
|
|
21 |
oss_path_img_gallery = "hejunjie.hjj/TransferAnythingHF_img_gallery"
|
22 |
|
23 |
|
24 |
+
def resize(image, short_side_length=768):
|
25 |
+
width, height = image.size
|
26 |
+
ratio = short_side_length / min(width, height)
|
27 |
+
new_width = int(width * ratio)
|
28 |
+
new_height = int(height * ratio)
|
29 |
+
resized_image = image.resize((new_width, new_height))
|
30 |
+
return resized_image
|
31 |
+
|
32 |
+
|
33 |
def download_img_pil(index, img_url):
|
34 |
# print(img_url)
|
35 |
r = requests.get(img_url, stream=True)
|
|
|
60 |
|
61 |
def upload_np_2_oss(input_image, name="cache.png", gallery=False):
|
62 |
assert name.lower().endswith((".png", ".jpg")), name
|
63 |
+
if name.lower().endswith(".png"):
|
64 |
+
name = name[:-4] + ".jpg"
|
65 |
imgByteArr = io.BytesIO()
|
66 |
if name.lower().endswith(".png"):
|
67 |
+
resize(Image.fromarray(input_image)).save(imgByteArr, format="PNG")
|
68 |
else:
|
69 |
+
resize(Image.fromarray(input_image)).save(imgByteArr, format="JPEG", quality=95)
|
70 |
imgByteArr = imgByteArr.getvalue()
|
71 |
|
72 |
if gallery:
|
|
|
74 |
else:
|
75 |
path = oss_path
|
76 |
|
77 |
+
start_time = time.perf_counter()
|
78 |
bucket.put_object(path + "/" + name, imgByteArr) # data为数据,可以是图片
|
79 |
ret = bucket.sign_url('GET', path + "/" + name, 60 * 60 * 24) # 返回值为链接,参数依次为,方法/oss上文件路径/过期时间(s)
|
80 |
+
logger.info(f"upload cost: {time.perf_counter() - start_time} s.")
|
81 |
del imgByteArr
|
82 |
return ret
|
83 |
|