Upload gradio_s3diff.py
Browse files- src/gradio_s3diff.py +2 -0
src/gradio_s3diff.py
CHANGED
@@ -17,6 +17,7 @@ from de_net import DEResNet
|
|
17 |
from s3diff_tile import S3Diff
|
18 |
from torchvision import transforms
|
19 |
from utils.wavelet_color import wavelet_color_fix, adain_color_fix
|
|
|
20 |
|
21 |
tensor_transforms = transforms.Compose([
|
22 |
transforms.ToTensor(),
|
@@ -65,6 +66,7 @@ device = "cuda"
|
|
65 |
net_sr.to(device, dtype=weight_dtype)
|
66 |
net_de.to(device, dtype=weight_dtype)
|
67 |
|
|
|
68 |
@torch.no_grad()
|
69 |
def process(
|
70 |
input_image: Image.Image,
|
|
|
17 |
from s3diff_tile import S3Diff
|
18 |
from torchvision import transforms
|
19 |
from utils.wavelet_color import wavelet_color_fix, adain_color_fix
|
20 |
+
from spaces import GPU # 追加
|
21 |
|
22 |
tensor_transforms = transforms.Compose([
|
23 |
transforms.ToTensor(),
|
|
|
66 |
net_sr.to(device, dtype=weight_dtype)
|
67 |
net_de.to(device, dtype=weight_dtype)
|
68 |
|
69 |
+
@GPU(duration=60) # GPUを利用する関数にデコレーターを追加
|
70 |
@torch.no_grad()
|
71 |
def process(
|
72 |
input_image: Image.Image,
|