duongttr committed on
Commit 62ef5f4 · 1 Parent(s): d3551a1

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +17 -0
  2. .gitignore +137 -0
  3. README.md +1 -7
  4. UI.py +81 -0
  5. app.py +215 -0
  6. checkpoints/colornet.pth +3 -0
  7. checkpoints/embed_net.pth +3 -0
  8. checkpoints/nonlocal_net.pth +3 -0
  9. cmd.txt +21 -0
  10. cmd_ddp.txt +20 -0
  11. docs/.gitignore +0 -0
  12. environment.yml +0 -0
  13. examples.zip +3 -0
  14. examples/bear/ref.jpg +0 -0
  15. examples/bear/video.mp4 +3 -0
  16. examples/boat/ref.jpg +0 -0
  17. examples/boat/video.mp4 +0 -0
  18. examples/cows/ref.jpg +0 -0
  19. examples/cows/video.mp4 +3 -0
  20. examples/flamingo/ref.jpg +0 -0
  21. examples/flamingo/video.mp4 +3 -0
  22. gradio_cached_examples/13/log.csv +5 -0
  23. gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4 +3 -0
  24. gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4 +0 -0
  25. gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4 +3 -0
  26. gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4 +0 -0
  27. inputs/video.mp4/000000000.jpg +0 -0
  28. inputs/video.mp4/000000001.jpg +0 -0
  29. inputs/video.mp4/000000002.jpg +0 -0
  30. inputs/video.mp4/000000003.jpg +0 -0
  31. inputs/video.mp4/000000004.jpg +0 -0
  32. inputs/video.mp4/000000005.jpg +0 -0
  33. inputs/video.mp4/000000006.jpg +0 -0
  34. inputs/video.mp4/000000007.jpg +0 -0
  35. inputs/video.mp4/000000008.jpg +0 -0
  36. inputs/video.mp4/000000009.jpg +0 -0
  37. inputs/video.mp4/000000010.jpg +0 -0
  38. inputs/video.mp4/000000011.jpg +0 -0
  39. inputs/video.mp4/000000012.jpg +0 -0
  40. inputs/video.mp4/000000013.jpg +0 -0
  41. inputs/video.mp4/000000014.jpg +0 -0
  42. inputs/video.mp4/000000015.jpg +0 -0
  43. inputs/video.mp4/000000016.jpg +0 -0
  44. inputs/video.mp4/000000017.jpg +0 -0
  45. inputs/video.mp4/000000018.jpg +0 -0
  46. inputs/video.mp4/000000019.jpg +0 -0
  47. inputs/video.mp4/000000020.jpg +0 -0
  48. inputs/video.mp4/000000021.jpg +0 -0
  49. inputs/video.mp4/000000022.jpg +0 -0
  50. inputs/video.mp4/000000023.jpg +0 -0
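
The commit message says the folder was pushed with huggingface_hub. For context, a minimal sketch of how such an upload is typically done with that library is shown below; the repo id, token, and folder path are placeholders, not values taken from this commit.

from huggingface_hub import HfApi

# Minimal sketch: push a local working directory to a Space in a single commit.
# repo_id, folder_path, and the token are placeholders for illustration only.
api = HfApi(token="hf_...")  # a write-scoped access token
api.upload_folder(
    folder_path=".",                  # local folder to upload
    repo_id="user/ViTExCo",           # placeholder Space id
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)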
.gitattributes CHANGED
@@ -33,3 +33,20 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/clips/bear/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/clips/bear/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/clips/boat/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/clips/cows/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/clips/cows/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/clips/dog/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/clips/flamingo/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/ref/goat/0000.jpg filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/ref/hockey/0000.jpg filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/ref/horsejump-high/0000.jpg filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/ref/motorbike/0000.jpg filter=lfs diff=lfs merge=lfs -text
+ EvalDataset/ref/surf/0000.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/bear/video.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/cows/video.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/flamingo/video.mp4 filter=lfs diff=lfs merge=lfs -text
+ gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
+ gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,137 @@
+ checkpoints/
+ wandb/
+ .vscode
+ .DS_Store
+ *ckpt*/
+ # Custom
+ *.pt
+ data/local
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
  title: ViTExCo
- emoji: 👀
- colorFrom: gray
- colorTo: green
+ app_file: app.py
  sdk: gradio
  sdk_version: 3.40.1
- app_file: app.py
- pinned: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
UI.py ADDED
@@ -0,0 +1,81 @@
+ import streamlit as st
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from streamlit_image_comparison import image_comparison
+ import numpy as np
+ import torch
+ import torchvision
+
+ ######################################### Utils ########################################
+ video_extensions = ["mp4"]
+ image_extensions = ["png", "jpg"]
+
+
+ def check_type(file_name: str):
+     for image_extension in image_extensions:
+         if file_name.endswith(image_extension):
+             return "image"
+     for video_extension in video_extensions:
+         if file_name.endswith(video_extension):
+             return "video"
+     return None
+
+
+ transform = transforms.Compose(
+     [transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
+ )
+
+
+ ###################################### Load model ######################################
+ @st.cache_resource
+ def load_model():
+     model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
+     model.eval()
+     return model
+
+
+ model = load_model()
+ ########################################## UI ##########################################
+ st.title("Colorization")
+
+ uploaded_file = st.file_uploader("Upload grayscale image or video", type=image_extensions + video_extensions)
+ if uploaded_file:
+     # Image
+     if check_type(file_name=uploaded_file.name) == "image":
+         image = np.array(Image.open(uploaded_file), dtype=np.float32)
+
+         input_tensor = torchvision.transforms.functional.normalize(
+             torch.tensor(image).permute(2, 0, 1),
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225],
+         ).unsqueeze(0)
+         process_button = st.button("Process")
+         if process_button:
+             with st.spinner("Từ từ coi..."):
+                 prediction = model(input_tensor)
+                 segment = prediction["out"][0].permute(1, 2, 0)
+                 segment = segment.detach().numpy()
+
+                 st.image(segment)
+                 st.image(image)
+
+                 image_comparison(
+                     img1=image,
+                     img2=np.array(segment),
+                     label1="Grayscale",
+                     label2="Colorized",
+                     make_responsive=True,
+                     show_labels=True,
+                 )
+     # Video
+     else:
+         # video = open(uploaded_file.name)
+         st.video("https://youtu.be/dQw4w9WgXcQ")
+
+ hide_menu_style = """
+         <style>
+         #MainMenu {visibility: hidden; }
+         footer {visibility: hidden;}
+         </style>
+         """
+ st.markdown(hide_menu_style, unsafe_allow_html=True)
app.py ADDED
@@ -0,0 +1,215 @@
+ import numpy as np
+ import shutil
+ import os
+ import argparse
+ import torch
+ import glob
+ from tqdm import tqdm
+ from PIL import Image
+ from collections import OrderedDict
+ from src.models.vit.config import load_config
+ import torchvision.transforms as transforms
+ import cv2
+ from skimage import io
+
+ from src.models.CNN.ColorVidNet import GeneralColorVidNet
+ from src.models.vit.embed import GeneralEmbedModel
+ from src.models.CNN.NonlocalNet import GeneralWarpNet
+ from src.models.CNN.FrameColor import frame_colorization
+ from src.utils import (
+     RGB2Lab,
+     ToTensor,
+     Normalize,
+     uncenter_l,
+     tensor_lab2rgb,
+     SquaredPadding,
+     UnpaddingSquare
+ )
+
+ import gradio as gr
+
+ def load_params(ckpt_file):
+     params = torch.load(ckpt_file, map_location=device)
+     new_params = []
+     for key, value in params.items():
+         new_params.append((key, value))
+     return OrderedDict(new_params)
+
+ def custom_transform(transforms, img):
+     for transform in transforms:
+         if isinstance(transform, SquaredPadding):
+             img,padding=transform(img, return_paddings=True)
+         else:
+             img = transform(img)
+     return img.to(device), padding
+
+ def save_frames(predicted_rgb, video_name, frame_name):
+     if predicted_rgb is not None:
+         predicted_rgb = np.clip(predicted_rgb, 0, 255).astype(np.uint8)
+         # frame_path_parts = frame_path.split(os.sep)
+         # if os.path.exists(os.path.join(OUTPUT_RESULT_PATH, frame_path_parts[-2])):
+         #     shutil.rmtree(os.path.join(OUTPUT_RESULT_PATH, frame_path_parts[-2]))
+         # os.makedirs(os.path.join(OUTPUT_RESULT_PATH, frame_path_parts[-2]), exist_ok=True)
+         predicted_rgb = np.transpose(predicted_rgb, (1,2,0))
+         pil_img = Image.fromarray(predicted_rgb)
+         pil_img.save(os.path.join(OUTPUT_RESULT_PATH, video_name, frame_name))
+
+ def extract_frames_from_video(video_path):
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS)
+
+     # remove if exists folder
+     output_frames_path = os.path.join(INPUT_VIDEO_FRAMES_PATH, os.path.basename(video_path))
+     if os.path.exists(output_frames_path):
+         shutil.rmtree(output_frames_path)
+
+     # make new folder
+     os.makedirs(output_frames_path)
+
+     currentframe = 0
+     frame_path_list = []
+     while(True):
+
+         # reading from frame
+         ret,frame = cap.read()
+
+         if ret:
+             name = os.path.join(output_frames_path, f'{currentframe:09d}.jpg')
+             frame_path_list.append(name)
+             cv2.imwrite(name, frame)
+             currentframe += 1
+         else:
+             break
+
+     cap.release()
+     cv2.destroyAllWindows()
+
+     return frame_path_list, fps
+
+ def combine_frames_from_folder(frames_list_path, fps = 30):
+     frames_list = glob.glob(f'{frames_list_path}/*.jpg')
+     frames_list.sort()
+
+     sample_shape = cv2.imread(frames_list[0]).shape
+
+     output_video_path = os.path.join(frames_list_path, 'output_video.mp4')
+     out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (sample_shape[1], sample_shape[0]))
+     for filename in frames_list:
+         img = cv2.imread(filename)
+         out.write(img)
+
+     out.release()
+     return output_video_path
+
+
+ def upscale_image(I_current_rgb, I_current_ab_predict):
+     H, W = I_current_rgb.size
+     high_lab_transforms = [
+         SquaredPadding(target_size=max(H,W)),
+         RGB2Lab(),
+         ToTensor(),
+         Normalize()
+     ]
+     # current_frame_pil_rgb = Image.fromarray(np.clip(I_current_rgb.squeeze(0).permute(1,2,0).cpu().numpy() * 255, 0, 255).astype('uint8'))
+     high_lab_current, paddings = custom_transform(high_lab_transforms, I_current_rgb)
+     high_lab_current = torch.unsqueeze(high_lab_current,dim=0).to(device)
+     high_l_current = high_lab_current[:, 0:1, :, :]
+     high_ab_current = high_lab_current[:, 1:3, :, :]
+     upsampler = torch.nn.Upsample(scale_factor=max(H,W)/224,mode="bilinear")
+     high_ab_predict = upsampler(I_current_ab_predict)
+     I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))
+     upadded = UnpaddingSquare()
+     I_predict_rgb = upadded(I_predict_rgb, paddings)
+     return I_predict_rgb
+
+ def colorize_video(video_path, ref_np):
+     frames_list, fps = extract_frames_from_video(video_path)
+
+     frame_ref = Image.fromarray(ref_np).convert("RGB")
+     I_last_lab_predict = None
+     IB_lab, IB_paddings = custom_transform(transforms, frame_ref)
+     IB_lab = IB_lab.unsqueeze(0).to(device)
+     IB_l = IB_lab[:, 0:1, :, :]
+     IB_ab = IB_lab[:, 1:3, :, :]
+
+     with torch.no_grad():
+         I_reference_lab = IB_lab
+         I_reference_l = I_reference_lab[:, 0:1, :, :]
+         I_reference_ab = I_reference_lab[:, 1:3, :, :]
+         I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
+         features_B = embed_net(I_reference_rgb)
+
+     video_path_parts = frames_list[0].split(os.sep)
+
+     if os.path.exists(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2])):
+         shutil.rmtree(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]))
+     os.makedirs(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]), exist_ok=True)
+
+     for frame_path in tqdm(frames_list):
+         curr_frame = Image.open(frame_path).convert("RGB")
+         IA_lab, IA_paddings = custom_transform(transforms, curr_frame)
+         IA_lab = IA_lab.unsqueeze(0).to(device)
+         IA_l = IA_lab[:, 0:1, :, :]
+         IA_ab = IA_lab[:, 1:3, :, :]
+
+         if I_last_lab_predict is None:
+             I_last_lab_predict = torch.zeros_like(IA_lab).to(device)
+
+         with torch.no_grad():
+             I_current_lab = IA_lab
+             I_current_ab_predict, _ = frame_colorization(
+                 IA_l,
+                 I_reference_lab,
+                 I_last_lab_predict,
+                 features_B,
+                 embed_net,
+                 nonlocal_net,
+                 colornet,
+                 luminance_noise=0,
+                 temperature=1e-10,
+                 joint_training=False
+             )
+             I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)
+
+         # IA_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(IA_l), I_current_ab_predict), dim=1))
+         IA_predict_rgb = upscale_image(curr_frame, I_current_ab_predict)
+         #IA_predict_rgb = torch.nn.functional.upsample_bilinear(IA_predict_rgb, scale_factor=2)
+         save_frames(IA_predict_rgb.squeeze(0).cpu().numpy() * 255, video_path_parts[-2], os.path.basename(frame_path))
+     return combine_frames_from_folder(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]), fps)
+
+ if __name__ == '__main__':
+     # Init global variables
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     INPUT_VIDEO_FRAMES_PATH = 'inputs'
+     OUTPUT_RESULT_PATH = 'outputs'
+     weight_path = 'checkpoints'
+
+     embed_net=GeneralEmbedModel(pretrained_model="swin-tiny", device=device).to(device)
+     nonlocal_net = GeneralWarpNet(feature_channel=128).to(device)
+     colornet=GeneralColorVidNet(7).to(device)
+
+     embed_net.eval()
+     nonlocal_net.eval()
+     colornet.eval()
+
+     # Load weights
+     # embed_net_params = load_params(os.path.join(weight_path, "embed_net.pth"))
+     nonlocal_net_params = load_params(os.path.join(weight_path, "nonlocal_net.pth"))
+     colornet_params = load_params(os.path.join(weight_path, "colornet.pth"))
+
+     # embed_net.load_state_dict(embed_net_params, strict=True)
+     nonlocal_net.load_state_dict(nonlocal_net_params, strict=True)
+     colornet.load_state_dict(colornet_params, strict=True)
+
+     transforms = [SquaredPadding(target_size=224),
+                   RGB2Lab(),
+                   ToTensor(),
+                   Normalize()]
+
+     examples = [[vid, ref] for vid, ref in zip(sorted(glob.glob('examples/*/*.mp4')), sorted(glob.glob('examples/*/*.jpg')))]
+     demo = gr.Interface(colorize_video,
+                         inputs=[gr.Video(), gr.Image()],
+                         outputs="playable_video",
+                         examples=examples,
+                         cache_examples=True)
+     demo.launch()
checkpoints/colornet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5257ae325e292cd5fb2eff47095e1c4e4815455bd5fb6dc5ed2ee2b923172875
+ size 131239411
checkpoints/embed_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc711755a75c43025dabe9407cbd11d164eaa9e21f26430d0c16c7493410d902
+ size 110352261
checkpoints/nonlocal_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b94c6990f20088bc3cc3fe0b29a6d52e6e746b915c506f0cd349fc6ad6197e72
+ size 73189765
cmd.txt ADDED
@@ -0,0 +1,21 @@
+ python train.py --video_data_root_list datasets/images/images \
+ --flow_data_root_list datasets/flow_fp16/flow_fp16 \
+ --mask_data_root_list datasets/pgm/pgm \
+ --data_root_imagenet datasets/imgnet \
+ --annotation_file_path datasets/final_annot.csv \
+ --imagenet_pairs_file datasets/pairs.txt \
+ --gpu_ids 0 \
+ --workers 12 \
+ --batch_size 2 \
+ --real_reference_probability 0.99 \
+ --weight_contextual 1 \
+ --weight_perceptual 0.1 \
+ --weight_smoothness 5 \
+ --weight_gan 0.9 \
+ --weight_consistent 0.1 \
+ --use_wandb True \
+ --wandb_token "f05d31e6b15339b1cfc5ee1c77fe51f66fc3ea9e" \
+ --wandb_name "vit_tiny_patch16_384_nofeat" \
+ --checkpoint_step 500 \
+ --epoch_train_discriminator 3 \
+ --epoch 20
cmd_ddp.txt ADDED
@@ -0,0 +1,20 @@
+ !torchrun --nnodes=1 --nproc_per_node=2 train_ddp.py --video_data_root_list $video_data_root_list \
+ --flow_data_root_list $flow_data_root_list \
+ --mask_data_root_list $mask_data_root_list \
+ --data_root_imagenet $data_root_imagenet \
+ --annotation_file_path $annotation_file_path \
+ --imagenet_pairs_file $imagenet_pairs_file \
+ --gpu_ids "0,1" \
+ --workers 2 \
+ --batch_size 2 \
+ --real_reference_probability 0.99 \
+ --weight_contextual 1 \
+ --weight_perceptual 0.1 \
+ --weight_smoothness 5 \
+ --weight_gan 0.9 \
+ --weight_consistent 0.1 \
+ --wandb_token "165e7148081f263b423722115e2ad40fa5339ecf" \
+ --wandb_name "vit_tiny_patch16_384_nofeat" \
+ --checkpoint_step 2000 \
+ --epoch_train_discriminator 2 \
+ --epoch 10
docs/.gitignore ADDED
File without changes
environment.yml ADDED
File without changes
examples.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd4531bd3abdec6df90efb0d19fadd54284bdc70d5edfff19752a205159eb4db
+ size 6955837
examples/bear/ref.jpg ADDED
examples/bear/video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb4cec5064873a4616f78bdb653830683a4842b2a5cfd0665b395cff4d120d04
+ size 1263445
examples/boat/ref.jpg ADDED
examples/boat/video.mp4 ADDED
Binary file (853 kB)
examples/cows/ref.jpg ADDED
examples/cows/video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1ac08603d719cd7a8d71fac76c9318d3e8f1e516e9b3c2a06323a0e4e78f6410
+ size 2745681
examples/flamingo/ref.jpg ADDED
examples/flamingo/video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a103fd4991a00e419e5236b885fe9d220704ba0a6ac794c87aaa3f62a4f1561
+ size 1239570
gradio_cached_examples/13/log.csv ADDED
@@ -0,0 +1,5 @@
+ output,flag,username,timestamp
+ /content/ViTExCo/gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4,,,2023-08-15 09:45:37.897615
+ /content/ViTExCo/gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4,,,2023-08-15 09:46:01.048997
+ /content/ViTExCo/gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4,,,2023-08-15 09:46:34.503322
+ /content/ViTExCo/gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4,,,2023-08-15 09:46:58.088903
gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b5ab666998e14fb00281a90f8801753eca001a432641ae2770007a8336b4c64e
+ size 1213824
gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4 ADDED
Binary file (914 kB)
gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c367dab34e596f7f0fed34c7e2384525de2ba1824b410d0770bdbd17bc9e72a
+ size 1793060
gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4 ADDED
Binary file (673 kB)
inputs/video.mp4/000000000.jpg ADDED
inputs/video.mp4/000000001.jpg ADDED
inputs/video.mp4/000000002.jpg ADDED
inputs/video.mp4/000000003.jpg ADDED
inputs/video.mp4/000000004.jpg ADDED
inputs/video.mp4/000000005.jpg ADDED
inputs/video.mp4/000000006.jpg ADDED
inputs/video.mp4/000000007.jpg ADDED
inputs/video.mp4/000000008.jpg ADDED
inputs/video.mp4/000000009.jpg ADDED
inputs/video.mp4/000000010.jpg ADDED
inputs/video.mp4/000000011.jpg ADDED
inputs/video.mp4/000000012.jpg ADDED
inputs/video.mp4/000000013.jpg ADDED
inputs/video.mp4/000000014.jpg ADDED
inputs/video.mp4/000000015.jpg ADDED
inputs/video.mp4/000000016.jpg ADDED
inputs/video.mp4/000000017.jpg ADDED
inputs/video.mp4/000000018.jpg ADDED
inputs/video.mp4/000000019.jpg ADDED
inputs/video.mp4/000000020.jpg ADDED
inputs/video.mp4/000000021.jpg ADDED
inputs/video.mp4/000000022.jpg ADDED
inputs/video.mp4/000000023.jpg ADDED