Spaces:
Running
on
Zero
Running
on
Zero
This PR fixes the space
#10
by
Fabrice-TIERCELIN
- opened
app.py
CHANGED
@@ -2,16 +2,19 @@ import spaces
|
|
2 |
import os
|
3 |
import gradio as gr
|
4 |
from video_super_resolution.scripts.inference_sr import STAR_sr
|
|
|
|
|
|
|
5 |
|
6 |
# Example video and prompt pairs
|
7 |
examples = [
|
8 |
-
["examples/023_klingai_reedit.mp4", "The video shows a panda strumming a guitar on a rock by a tranquil lake at sunset. With its black-and-white fur, the panda sits against a backdrop of mountains and a vibrant sky painted in orange and pink hues. The serene scene highlights relaxation and whimsy, with the panda, guitar, and lake harmoniously positioned. The natural landscape's depth and perspective enhance the focus on the panda's peaceful interaction with the guitar.",
|
9 |
["examples/017_klingai_reedit.mp4", "The video depicts a majestic lion with eagle-like wings standing on a grassy hill against rolling green hills and a clear sky. The lion’s golden mane contrasts with the warm hues of the scene, and its intense gaze draws focus. The detailed, fully spread wings add a fantastical element. A 'PremiumBeat' watermark appears in the lower right, hinting at the image's source. The style blends realism with fantasy, showcasing the lion's mythical nature.", 4, 24, 250],
|
10 |
["examples/016_video.mp4", "The video is a black-and-white silent film featuring two men in wheelchairs on a pier. The foreground man, in a suit and hat, holds a sign reading 'HELP CRIPPLE.' The background shows a building and a boat, with early 20th-century clothing and image quality suggesting a narrative of disability and assistance.", 4, 24, 300],
|
11 |
]
|
12 |
|
13 |
# Define a GPU-decorated function for enhancement
|
14 |
-
@spaces.GPU()
|
15 |
def enhance_with_gpu(input_video, input_text, upscale, max_chunk_len, chunk_size):
|
16 |
"""在每次调用时创建新的 STAR_sr 实例,确保参数正确传递"""
|
17 |
star = STAR_sr(
|
@@ -49,7 +52,7 @@ def star_demo(result_dir="./tmp/"):
|
|
49 |
|
50 |
gr.Examples(
|
51 |
examples=examples,
|
52 |
-
inputs=[input_video, input_text],
|
53 |
outputs=[output_video],
|
54 |
fn=enhance_with_gpu, # Use the GPU-decorated function
|
55 |
cache_examples=False,
|
|
|
2 |
import os
|
3 |
import gradio as gr
|
4 |
from video_super_resolution.scripts.inference_sr import STAR_sr
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
|
7 |
+
hf_hub_download(repo_id="SherryX/STAR", filename="I2VGen-XL-based/heavy_deg.pt", local_dir="pretrained_weight")
|
8 |
|
9 |
# Example video and prompt pairs
|
10 |
examples = [
|
11 |
+
["examples/023_klingai_reedit.mp4", "The video shows a panda strumming a guitar on a rock by a tranquil lake at sunset. With its black-and-white fur, the panda sits against a backdrop of mountains and a vibrant sky painted in orange and pink hues. The serene scene highlights relaxation and whimsy, with the panda, guitar, and lake harmoniously positioned. The natural landscape's depth and perspective enhance the focus on the panda's peaceful interaction with the guitar.", 2, 24, 250],
|
12 |
["examples/017_klingai_reedit.mp4", "The video depicts a majestic lion with eagle-like wings standing on a grassy hill against rolling green hills and a clear sky. The lion’s golden mane contrasts with the warm hues of the scene, and its intense gaze draws focus. The detailed, fully spread wings add a fantastical element. A 'PremiumBeat' watermark appears in the lower right, hinting at the image's source. The style blends realism with fantasy, showcasing the lion's mythical nature.", 4, 24, 250],
|
13 |
["examples/016_video.mp4", "The video is a black-and-white silent film featuring two men in wheelchairs on a pier. The foreground man, in a suit and hat, holds a sign reading 'HELP CRIPPLE.' The background shows a building and a boat, with early 20th-century clothing and image quality suggesting a narrative of disability and assistance.", 4, 24, 300],
|
14 |
]
|
15 |
|
16 |
# Define a GPU-decorated function for enhancement
|
17 |
+
@spaces.GPU(duration=180)
|
18 |
def enhance_with_gpu(input_video, input_text, upscale, max_chunk_len, chunk_size):
|
19 |
"""在每次调用时创建新的 STAR_sr 实例,确保参数正确传递"""
|
20 |
star = STAR_sr(
|
|
|
52 |
|
53 |
gr.Examples(
|
54 |
examples=examples,
|
55 |
+
inputs=[input_video, input_text, upscale, max_chunk_len, chunk_size],
|
56 |
outputs=[output_video],
|
57 |
fn=enhance_with_gpu, # Use the GPU-decorated function
|
58 |
cache_examples=False,
|
video_super_resolution/scripts/inference_sr.py
CHANGED
@@ -1,142 +1,142 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
from argparse import ArgumentParser, Namespace
|
4 |
-
import json
|
5 |
-
from typing import Any, Dict, List, Mapping, Tuple
|
6 |
-
from easydict import EasyDict
|
7 |
-
|
8 |
-
from video_to_video.video_to_video_model import VideoToVideo_sr
|
9 |
-
from video_to_video.utils.seed import setup_seed
|
10 |
-
from video_to_video.utils.logger import get_logger
|
11 |
-
from video_super_resolution.color_fix import adain_color_fix
|
12 |
-
|
13 |
-
from inference_utils import *
|
14 |
-
|
15 |
-
logger = get_logger()
|
16 |
-
|
17 |
-
|
18 |
-
class STAR_sr():
|
19 |
-
def __init__(self,
|
20 |
-
result_dir='./results/',
|
21 |
-
file_name='000_video.mp4',
|
22 |
-
model_path='./pretrained_weight',
|
23 |
-
solver_mode='fast',
|
24 |
-
steps=15,
|
25 |
-
guide_scale=7.5,
|
26 |
-
upscale=4,
|
27 |
-
max_chunk_len=32,
|
28 |
-
variant_info=None,
|
29 |
-
chunk_size=3,
|
30 |
-
):
|
31 |
-
self.model_path=model_path
|
32 |
-
logger.info('checkpoint_path: {}'.format(self.model_path))
|
33 |
-
|
34 |
-
self.result_dir = result_dir
|
35 |
-
self.file_name = file_name
|
36 |
-
os.makedirs(self.result_dir, exist_ok=True)
|
37 |
-
|
38 |
-
model_cfg = EasyDict(__name__='model_cfg')
|
39 |
-
model_cfg.model_path = self.model_path
|
40 |
-
model_cfg.chunk_size = chunk_size
|
41 |
-
self.model = VideoToVideo_sr(model_cfg)
|
42 |
-
|
43 |
-
steps = 15 if solver_mode == 'fast' else steps
|
44 |
-
self.solver_mode=solver_mode
|
45 |
-
self.steps=steps
|
46 |
-
self.guide_scale=guide_scale
|
47 |
-
self.upscale = upscale
|
48 |
-
self.max_chunk_len=max_chunk_len
|
49 |
-
self.variant_info=variant_info
|
50 |
-
|
51 |
-
def enhance_a_video(self, video_path, prompt):
|
52 |
-
logger.info('input video path: {}'.format(video_path))
|
53 |
-
text = prompt
|
54 |
-
logger.info('text: {}'.format(text))
|
55 |
-
caption = text + self.model.positive_prompt
|
56 |
-
|
57 |
-
input_frames, input_fps = load_video(video_path)
|
58 |
-
in_f_num = len(input_frames)
|
59 |
-
logger.info('input frames length: {}'.format(in_f_num))
|
60 |
-
logger.info('input fps: {}'.format(input_fps))
|
61 |
-
|
62 |
-
video_data = preprocess(input_frames)
|
63 |
-
_, _, h, w = video_data.shape
|
64 |
-
logger.info('input resolution: {}'.format((h, w)))
|
65 |
-
target_h, target_w = h * self.upscale, w * self.upscale # adjust_resolution(h, w, up_scale=4)
|
66 |
-
logger.info('target resolution: {}'.format((target_h, target_w)))
|
67 |
-
|
68 |
-
pre_data = {'video_data': video_data, 'y': caption}
|
69 |
-
pre_data['target_res'] = (target_h, target_w)
|
70 |
-
|
71 |
-
total_noise_levels = 900
|
72 |
-
setup_seed(666)
|
73 |
-
|
74 |
-
with torch.no_grad():
|
75 |
-
data_tensor = collate_fn(pre_data, 'cuda:0')
|
76 |
-
output = self.model.test(data_tensor, total_noise_levels, steps=self.steps, \
|
77 |
-
solver_mode=self.solver_mode, guide_scale=self.guide_scale, \
|
78 |
-
max_chunk_len=self.max_chunk_len
|
79 |
-
)
|
80 |
-
|
81 |
-
output = tensor2vid(output)
|
82 |
-
|
83 |
-
# Using color fix
|
84 |
-
output = adain_color_fix(output, video_data)
|
85 |
-
|
86 |
-
save_video(output, self.result_dir, self.file_name, fps=input_fps)
|
87 |
-
return os.path.join(self.result_dir, self.file_name)
|
88 |
-
|
89 |
-
|
90 |
-
def parse_args():
|
91 |
-
parser = ArgumentParser()
|
92 |
-
|
93 |
-
parser.add_argument("--input_path", required=True, type=str, help="input video path")
|
94 |
-
parser.add_argument("--save_dir", type=str, default='results', help="save directory")
|
95 |
-
parser.add_argument("--file_name", type=str, help="file name")
|
96 |
-
parser.add_argument("--model_path", type=str, default='./pretrained_weight/
|
97 |
-
parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
|
98 |
-
parser.add_argument("--upscale", type=int, default=4, help='up-scale')
|
99 |
-
parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
|
100 |
-
parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')
|
101 |
-
|
102 |
-
parser.add_argument("--cfg", type=float, default=7.5)
|
103 |
-
parser.add_argument("--solver_mode", type=str, default='fast', help='fast | normal')
|
104 |
-
parser.add_argument("--steps", type=int, default=15)
|
105 |
-
|
106 |
-
return parser.parse_args()
|
107 |
-
|
108 |
-
def main():
|
109 |
-
|
110 |
-
args = parse_args()
|
111 |
-
|
112 |
-
input_path = args.input_path
|
113 |
-
prompt = args.prompt
|
114 |
-
model_path = args.model_path
|
115 |
-
save_dir = args.save_dir
|
116 |
-
file_name = args.file_name
|
117 |
-
upscale = args.upscale
|
118 |
-
max_chunk_len = args.max_chunk_len
|
119 |
-
|
120 |
-
steps = args.steps
|
121 |
-
solver_mode = args.solver_mode
|
122 |
-
guide_scale = args.cfg
|
123 |
-
|
124 |
-
assert solver_mode in ('fast', 'normal')
|
125 |
-
|
126 |
-
star_sr = STAR_sr(
|
127 |
-
result_dir=save_dir,
|
128 |
-
file_name=file_name, # new added
|
129 |
-
model_path=model_path,
|
130 |
-
solver_mode=solver_mode,
|
131 |
-
steps=steps,
|
132 |
-
guide_scale=guide_scale,
|
133 |
-
upscale=upscale,
|
134 |
-
max_chunk_len=max_chunk_len,
|
135 |
-
variant_info=None,
|
136 |
-
)
|
137 |
-
|
138 |
-
star_sr.enhance_a_video(input_path, prompt)
|
139 |
-
|
140 |
-
|
141 |
-
if __name__ == '__main__':
|
142 |
-
main()
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from argparse import ArgumentParser, Namespace
|
4 |
+
import json
|
5 |
+
from typing import Any, Dict, List, Mapping, Tuple
|
6 |
+
from easydict import EasyDict
|
7 |
+
|
8 |
+
from video_to_video.video_to_video_model import VideoToVideo_sr
|
9 |
+
from video_to_video.utils.seed import setup_seed
|
10 |
+
from video_to_video.utils.logger import get_logger
|
11 |
+
from video_super_resolution.color_fix import adain_color_fix
|
12 |
+
|
13 |
+
from inference_utils import *
|
14 |
+
|
15 |
+
logger = get_logger()
|
16 |
+
|
17 |
+
|
18 |
+
class STAR_sr():
|
19 |
+
def __init__(self,
|
20 |
+
result_dir='./results/',
|
21 |
+
file_name='000_video.mp4',
|
22 |
+
model_path='./pretrained_weight',
|
23 |
+
solver_mode='fast',
|
24 |
+
steps=15,
|
25 |
+
guide_scale=7.5,
|
26 |
+
upscale=4,
|
27 |
+
max_chunk_len=32,
|
28 |
+
variant_info=None,
|
29 |
+
chunk_size=3,
|
30 |
+
):
|
31 |
+
self.model_path=model_path
|
32 |
+
logger.info('checkpoint_path: {}'.format(self.model_path))
|
33 |
+
|
34 |
+
self.result_dir = result_dir
|
35 |
+
self.file_name = file_name
|
36 |
+
os.makedirs(self.result_dir, exist_ok=True)
|
37 |
+
|
38 |
+
model_cfg = EasyDict(__name__='model_cfg')
|
39 |
+
model_cfg.model_path = self.model_path
|
40 |
+
model_cfg.chunk_size = chunk_size
|
41 |
+
self.model = VideoToVideo_sr(model_cfg)
|
42 |
+
|
43 |
+
steps = 15 if solver_mode == 'fast' else steps
|
44 |
+
self.solver_mode=solver_mode
|
45 |
+
self.steps=steps
|
46 |
+
self.guide_scale=guide_scale
|
47 |
+
self.upscale = upscale
|
48 |
+
self.max_chunk_len=max_chunk_len
|
49 |
+
self.variant_info=variant_info
|
50 |
+
|
51 |
+
def enhance_a_video(self, video_path, prompt):
|
52 |
+
logger.info('input video path: {}'.format(video_path))
|
53 |
+
text = prompt
|
54 |
+
logger.info('text: {}'.format(text))
|
55 |
+
caption = text + self.model.positive_prompt
|
56 |
+
|
57 |
+
input_frames, input_fps = load_video(video_path)
|
58 |
+
in_f_num = len(input_frames)
|
59 |
+
logger.info('input frames length: {}'.format(in_f_num))
|
60 |
+
logger.info('input fps: {}'.format(input_fps))
|
61 |
+
|
62 |
+
video_data = preprocess(input_frames)
|
63 |
+
_, _, h, w = video_data.shape
|
64 |
+
logger.info('input resolution: {}'.format((h, w)))
|
65 |
+
target_h, target_w = h * self.upscale, w * self.upscale # adjust_resolution(h, w, up_scale=4)
|
66 |
+
logger.info('target resolution: {}'.format((target_h, target_w)))
|
67 |
+
|
68 |
+
pre_data = {'video_data': video_data, 'y': caption}
|
69 |
+
pre_data['target_res'] = (target_h, target_w)
|
70 |
+
|
71 |
+
total_noise_levels = 900
|
72 |
+
setup_seed(666)
|
73 |
+
|
74 |
+
with torch.no_grad():
|
75 |
+
data_tensor = collate_fn(pre_data, 'cuda:0')
|
76 |
+
output = self.model.test(data_tensor, total_noise_levels, steps=self.steps, \
|
77 |
+
solver_mode=self.solver_mode, guide_scale=self.guide_scale, \
|
78 |
+
max_chunk_len=self.max_chunk_len
|
79 |
+
)
|
80 |
+
|
81 |
+
output = tensor2vid(output)
|
82 |
+
|
83 |
+
# Using color fix
|
84 |
+
output = adain_color_fix(output, video_data)
|
85 |
+
|
86 |
+
save_video(output, self.result_dir, self.file_name, fps=input_fps)
|
87 |
+
return os.path.join(self.result_dir, self.file_name)
|
88 |
+
|
89 |
+
|
90 |
+
def parse_args():
|
91 |
+
parser = ArgumentParser()
|
92 |
+
|
93 |
+
parser.add_argument("--input_path", required=True, type=str, help="input video path")
|
94 |
+
parser.add_argument("--save_dir", type=str, default='results', help="save directory")
|
95 |
+
parser.add_argument("--file_name", type=str, help="file name")
|
96 |
+
parser.add_argument("--model_path", type=str, default='./pretrained_weight/I2VGen-XL-based/heavy_deg.pt', help="model path")
|
97 |
+
parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
|
98 |
+
parser.add_argument("--upscale", type=int, default=4, help='up-scale')
|
99 |
+
parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
|
100 |
+
parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')
|
101 |
+
|
102 |
+
parser.add_argument("--cfg", type=float, default=7.5)
|
103 |
+
parser.add_argument("--solver_mode", type=str, default='fast', help='fast | normal')
|
104 |
+
parser.add_argument("--steps", type=int, default=15)
|
105 |
+
|
106 |
+
return parser.parse_args()
|
107 |
+
|
108 |
+
def main():
|
109 |
+
|
110 |
+
args = parse_args()
|
111 |
+
|
112 |
+
input_path = args.input_path
|
113 |
+
prompt = args.prompt
|
114 |
+
model_path = args.model_path
|
115 |
+
save_dir = args.save_dir
|
116 |
+
file_name = args.file_name
|
117 |
+
upscale = args.upscale
|
118 |
+
max_chunk_len = args.max_chunk_len
|
119 |
+
|
120 |
+
steps = args.steps
|
121 |
+
solver_mode = args.solver_mode
|
122 |
+
guide_scale = args.cfg
|
123 |
+
|
124 |
+
assert solver_mode in ('fast', 'normal')
|
125 |
+
|
126 |
+
star_sr = STAR_sr(
|
127 |
+
result_dir=save_dir,
|
128 |
+
file_name=file_name, # new added
|
129 |
+
model_path=model_path,
|
130 |
+
solver_mode=solver_mode,
|
131 |
+
steps=steps,
|
132 |
+
guide_scale=guide_scale,
|
133 |
+
upscale=upscale,
|
134 |
+
max_chunk_len=max_chunk_len,
|
135 |
+
variant_info=None,
|
136 |
+
)
|
137 |
+
|
138 |
+
star_sr.enhance_a_video(input_path, prompt)
|
139 |
+
|
140 |
+
|
141 |
+
if __name__ == '__main__':
|
142 |
+
main()
|
video_super_resolution/scripts/inference_sr.sh
CHANGED
@@ -45,7 +45,7 @@ for i in "${!mp4_files[@]}"; do
|
|
45 |
--solver_mode 'fast' \
|
46 |
--steps 15 \
|
47 |
--input_path "${mp4_file}" \
|
48 |
-
--model_path /mnt/bn/videodataset/VSR/pretrained_models/STAR/
|
49 |
--prompt "${line}" \
|
50 |
--upscale 4 \
|
51 |
--max_chunk_len ${frame_length} \
|
|
|
45 |
--solver_mode 'fast' \
|
46 |
--steps 15 \
|
47 |
--input_path "${mp4_file}" \
|
48 |
+
--model_path /mnt/bn/videodataset/VSR/pretrained_models/STAR/heavy_deg.pt \
|
49 |
--prompt "${line}" \
|
50 |
--upscale 4 \
|
51 |
--max_chunk_len ${frame_length} \
|
video_to_video/video_to_video_model.py
CHANGED
@@ -17,10 +17,10 @@ from diffusers import AutoencoderKLTemporalDecoder
|
|
17 |
import requests
|
18 |
|
19 |
def download_model(url, model_path):
|
20 |
-
if not os.path.exists(os.path.join(model_path, '
|
21 |
print(f"Model not found at {model_path}, downloading...")
|
22 |
response = requests.get(url, stream=True)
|
23 |
-
with open(os.path.join(model_path, '
|
24 |
for chunk in response.iter_content(chunk_size=1024):
|
25 |
if chunk:
|
26 |
f.write(chunk)
|
@@ -54,7 +54,7 @@ class VideoToVideo_sr():
|
|
54 |
download_model(model_url, cfg.model_path)
|
55 |
|
56 |
# 拼接完整路径
|
57 |
-
model_file_path = os.path.join(
|
58 |
print('model_file_path:', model_file_path)
|
59 |
|
60 |
# 加载模型
|
|
|
17 |
import requests
|
18 |
|
19 |
def download_model(url, model_path):
|
20 |
+
if not os.path.exists(os.path.join(model_path, 'heavy_deg.pt')):
|
21 |
print(f"Model not found at {model_path}, downloading...")
|
22 |
response = requests.get(url, stream=True)
|
23 |
+
with open(os.path.join(model_path, 'heavy_deg.pt'), 'wb') as f:
|
24 |
for chunk in response.iter_content(chunk_size=1024):
|
25 |
if chunk:
|
26 |
f.write(chunk)
|
|
|
54 |
download_model(model_url, cfg.model_path)
|
55 |
|
56 |
# 拼接完整路径
|
57 |
+
model_file_path = os.path.join('pretrained_weight', 'I2VGen-XL-based', 'heavy_deg.pt')
|
58 |
print('model_file_path:', model_file_path)
|
59 |
|
60 |
# 加载模型
|