Fabrice-TIERCELIN commited on
Commit
602cbe0
·
verified ·
1 Parent(s): 0c79af5

Update video_super_resolution/scripts/inference_sr.py

Browse files
video_super_resolution/scripts/inference_sr.py CHANGED
@@ -1,142 +1,56 @@
1
- import os
2
- import torch
3
- from argparse import ArgumentParser, Namespace
4
- import json
5
- from typing import Any, Dict, List, Mapping, Tuple
6
- from easydict import EasyDict
7
-
8
- from video_to_video.video_to_video_model import VideoToVideo_sr
9
- from video_to_video.utils.seed import setup_seed
10
- from video_to_video.utils.logger import get_logger
11
- from video_super_resolution.color_fix import adain_color_fix
12
-
13
- from inference_utils import *
14
-
15
- logger = get_logger()
16
-
17
-
18
- class STAR_sr():
19
- def __init__(self,
20
- result_dir='./results/',
21
- file_name='000_video.mp4',
22
- model_path='./pretrained_weight',
23
- solver_mode='fast',
24
- steps=15,
25
- guide_scale=7.5,
26
- upscale=4,
27
- max_chunk_len=32,
28
- variant_info=None,
29
- chunk_size=3,
30
- ):
31
- self.model_path=model_path
32
- logger.info('checkpoint_path: {}'.format(self.model_path))
33
-
34
- self.result_dir = result_dir
35
- self.file_name = file_name
36
- os.makedirs(self.result_dir, exist_ok=True)
37
-
38
- model_cfg = EasyDict(__name__='model_cfg')
39
- model_cfg.model_path = self.model_path
40
- model_cfg.chunk_size = chunk_size
41
- self.model = VideoToVideo_sr(model_cfg)
42
-
43
- steps = 15 if solver_mode == 'fast' else steps
44
- self.solver_mode=solver_mode
45
- self.steps=steps
46
- self.guide_scale=guide_scale
47
- self.upscale = upscale
48
- self.max_chunk_len=max_chunk_len
49
- self.variant_info=variant_info
50
-
51
- def enhance_a_video(self, video_path, prompt):
52
- logger.info('input video path: {}'.format(video_path))
53
- text = prompt
54
- logger.info('text: {}'.format(text))
55
- caption = text + self.model.positive_prompt
56
-
57
- input_frames, input_fps = load_video(video_path)
58
- in_f_num = len(input_frames)
59
- logger.info('input frames length: {}'.format(in_f_num))
60
- logger.info('input fps: {}'.format(input_fps))
61
-
62
- video_data = preprocess(input_frames)
63
- _, _, h, w = video_data.shape
64
- logger.info('input resolution: {}'.format((h, w)))
65
- target_h, target_w = h * self.upscale, w * self.upscale # adjust_resolution(h, w, up_scale=4)
66
- logger.info('target resolution: {}'.format((target_h, target_w)))
67
-
68
- pre_data = {'video_data': video_data, 'y': caption}
69
- pre_data['target_res'] = (target_h, target_w)
70
-
71
- total_noise_levels = 900
72
- setup_seed(666)
73
-
74
- with torch.no_grad():
75
- data_tensor = collate_fn(pre_data, 'cuda:0')
76
- output = self.model.test(data_tensor, total_noise_levels, steps=self.steps, \
77
- solver_mode=self.solver_mode, guide_scale=self.guide_scale, \
78
- max_chunk_len=self.max_chunk_len
79
- )
80
-
81
- output = tensor2vid(output)
82
-
83
- # Using color fix
84
- output = adain_color_fix(output, video_data)
85
-
86
- save_video(output, self.result_dir, self.file_name, fps=input_fps)
87
- return os.path.join(self.result_dir, self.file_name)
88
-
89
-
90
- def parse_args():
91
- parser = ArgumentParser()
92
-
93
- parser.add_argument("--input_path", required=True, type=str, help="input video path")
94
- parser.add_argument("--save_dir", type=str, default='results', help="save directory")
95
- parser.add_argument("--file_name", type=str, help="file name")
96
- parser.add_argument("--model_path", type=str, default='./pretrained_weight/model.pt', help="model path")
97
- parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
98
- parser.add_argument("--upscale", type=int, default=4, help='up-scale')
99
- parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
100
- parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')
101
-
102
- parser.add_argument("--cfg", type=float, default=7.5)
103
- parser.add_argument("--solver_mode", type=str, default='fast', help='fast | normal')
104
- parser.add_argument("--steps", type=int, default=15)
105
-
106
- return parser.parse_args()
107
-
108
- def main():
109
-
110
- args = parse_args()
111
-
112
- input_path = args.input_path
113
- prompt = args.prompt
114
- model_path = args.model_path
115
- save_dir = args.save_dir
116
- file_name = args.file_name
117
- upscale = args.upscale
118
- max_chunk_len = args.max_chunk_len
119
-
120
- steps = args.steps
121
- solver_mode = args.solver_mode
122
- guide_scale = args.cfg
123
-
124
- assert solver_mode in ('fast', 'normal')
125
-
126
- star_sr = STAR_sr(
127
- result_dir=save_dir,
128
- file_name=file_name, # new added
129
- model_path=model_path,
130
- solver_mode=solver_mode,
131
- steps=steps,
132
- guide_scale=guide_scale,
133
- upscale=upscale,
134
- max_chunk_len=max_chunk_len,
135
- variant_info=None,
136
- )
137
-
138
- star_sr.enhance_a_video(input_path, prompt)
139
-
140
-
141
- if __name__ == '__main__':
142
- main()
 
1
+ #!/bin/bash
2
+
3
+ # Folder paths
4
+ video_folder_path='./input/video'
5
+ txt_file_path='./input/text/prompt.txt'
6
+
7
+ # Get all .mp4 files in the folder using find to handle special characters
8
+ mapfile -t mp4_files < <(find "$video_folder_path" -type f -name "*.mp4")
9
+
10
+ # Print the list of MP4 files
11
+ echo "MP4 files to be processed:"
12
+ for mp4_file in "${mp4_files[@]}"; do
13
+ echo "$mp4_file"
14
+ done
15
+
16
+ # Read lines from the text file, skipping empty lines
17
+ mapfile -t lines < <(grep -v '^\s*$' "$txt_file_path")
18
+
19
+ # List of frame counts
20
+ frame_length=32
21
+
22
+ # Debugging output
23
+ echo "Number of MP4 files: ${#mp4_files[@]}"
24
+ echo "Number of lines in the text file: ${#lines[@]}"
25
+
26
+ # Ensure the number of video files matches the number of lines
27
+ if [ ${#mp4_files[@]} -ne ${#lines[@]} ]; then
28
+ echo "Number of MP4 files and lines in the text file do not match."
29
+ exit 1
30
+ fi
31
+
32
+ # Loop through video files and corresponding lines
33
+ for i in "${!mp4_files[@]}"; do
34
+ mp4_file="${mp4_files[$i]}"
35
+ line="${lines[$i]}"
36
+
37
+ # Extract the filename without the extension
38
+ file_name=$(basename "$mp4_file" .mp4)
39
+
40
+ echo "Processing video file: $mp4_file with prompt: $line"
41
+
42
+ # Run Python script with parameters
43
+ python \
44
+ ./video_super_resolution/scripts/inference_sr.py \
45
+ --solver_mode 'fast' \
46
+ --steps 15 \
47
+ --input_path "${mp4_file}" \
48
+ --model_path /mnt/bn/videodataset/VSR/pretrained_models/STAR/heavy_deg.pt \
49
+ --prompt "${line}" \
50
+ --upscale 4 \
51
+ --max_chunk_len ${frame_length} \
52
+ --file_name "${file_name}.mp4" \
53
+ --save_dir ./results
54
+ done
55
+
56
+ echo "All videos processed successfully."