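# Runner for the vid2vid-zero demo Space: downloads a base diffusion model,
# builds an OmegaConf config from the user's inputs, runs inference through
# `accelerate launch`, and optionally releases the Space's GPU afterwards.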
from __future__ import annotations

import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import sys

import gradio as gr
import slugify
import torch
import huggingface_hub
from huggingface_hub import HfApi
from omegaconf import OmegaConf

ORIGINAL_SPACE_ID = 'BAAI/vid2vid-zero'
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)


class Runner:
    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token

        self.checkpoint_dir = pathlib.Path('checkpoints')
        self.checkpoint_dir.mkdir(exist_ok=True)

    def download_base_model(self, base_model_id: str, token=None) -> str:
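        """Fetch `base_model_id` into ./checkpoints and return a local path.

        Without a token the repo is cloned with git-lfs; with a token it is
        fetched into the Hugging Face cache via `snapshot_download`.
        """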
        model_dir = self.checkpoint_dir / base_model_id
        org_name = base_model_id.split('/')[0]
        org_dir = self.checkpoint_dir / org_name
        if not model_dir.exists():
            org_dir.mkdir(exist_ok=True)
        print(f'https://huggingface.co/{base_model_id}')
        if token is None:
            subprocess.run(shlex.split('git lfs install'), cwd=org_dir)
            subprocess.run(shlex.split(
                f'git lfs clone https://huggingface.co/{base_model_id}'),
                           cwd=org_dir)
            return model_dir.as_posix()
        else:
            temp_path = huggingface_hub.snapshot_download(base_model_id,
                                                          use_auth_token=token)
            print(temp_path, org_dir)
            # subprocess.run(shlex.split(f'mv {temp_path} {model_dir.as_posix()}'))
            # return model_dir.as_posix()
            return temp_path
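
    # NOTE: URL_TO_JOIN_MODEL_LIBRARY_ORG is referenced below but not defined
    # in this file; presumably it is an org-invite endpoint provided elsewhere
    # in the Space.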
    def join_model_library_org(self, token: str) -> None:
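        """Send a join request for the model library organization using `token`."""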
        subprocess.run(
            shlex.split(
                f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
            ))

    def run_vid2vid_zero(
        self,
        model_path: str,
        input_video: str,
        prompt: str,
        n_sample_frames: int,
        sample_start_idx: int,
        sample_frame_rate: int,
        validation_prompt: str,
        guidance_scale: float,
        resolution: str,
        seed: int,
        remove_gpu_after_running: bool,
        input_token: str | None = None,
    ) -> str:
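        """Run vid2vid-zero on `input_video` and return the output video path."""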
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if input_video is None:
            raise gr.Error('You need to upload a video.')
        if not prompt:
            raise gr.Error('The input prompt is missing.')
        if not validation_prompt:
            raise gr.Error('The validation prompt is missing.')
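
        # Values coming from the Gradio UI may be strings; coerce them to ints.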
        resolution = int(resolution)
        n_sample_frames = int(n_sample_frames)
        sample_start_idx = int(sample_start_idx)
        sample_frame_rate = int(sample_frame_rate)

        repo_dir = pathlib.Path(__file__).parent
        prompt_path = prompt.replace(' ', '_')
        output_dir = repo_dir / 'outputs' / prompt_path
        output_dir.mkdir(parents=True, exist_ok=True)
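
        # Start from the repo's example config and override it with user inputs.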
        config = OmegaConf.load('configs/black-swan.yaml')
        config.pretrained_model_path = self.download_base_model(
            model_path, token=input_token)
        # we remove null-text inversion & use fp16 for fast inference on the web demo
        config.mixed_precision = 'fp16'
        config.validation_data.use_null_inv = False
        config.output_dir = output_dir.as_posix()
        config.input_data.video_path = input_video.name  # type: ignore
        config.input_data.prompt = prompt
        config.input_data.n_sample_frames = n_sample_frames
        config.input_data.width = resolution
        config.input_data.height = resolution
        config.input_data.sample_start_idx = sample_start_idx
        config.input_data.sample_frame_rate = sample_frame_rate
        config.validation_data.prompts = [validation_prompt]
        config.validation_data.video_length = 8
        config.validation_data.width = resolution
        config.validation_data.height = resolution
        config.validation_data.num_inference_steps = 50
        config.validation_data.guidance_scale = guidance_scale
        config.input_batch_size = 1
        config.seed = seed

        config_path = output_dir / 'config.yaml'
        with open(config_path, 'w') as f:
            OmegaConf.save(config, f)
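
        # Run inference as a subprocess through `accelerate launch`.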
        command = f'accelerate launch test_vid2vid_zero.py --config {config_path}'
        subprocess.run(shlex.split(command))

        output_video_path = os.path.join(output_dir, 'sample-all.mp4')
        print(f'video path for gradio: {output_video_path}')
        message = 'Running completed!'
        print(message)
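
        # Optionally hand the Space's GPU back by requesting CPU hardware.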
        if remove_gpu_after_running:
            space_id = os.getenv('SPACE_ID')
            if space_id:
                api = HfApi(
                    token=self.hf_token if self.hf_token else input_token)
                api.request_space_hardware(repo_id=space_id,
                                           hardware='cpu-basic')

        return output_video_path
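

# A minimal usage sketch, not part of the original Space. It assumes a local
# checkout of the vid2vid-zero repo (so configs/black-swan.yaml and
# test_vid2vid_zero.py exist), a CUDA GPU, and a hypothetical input video at
# data/example.mp4; the SimpleNamespace stands in for the file object that
# Gradio would normally pass in.
if __name__ == '__main__':
    from types import SimpleNamespace

    runner = Runner(hf_token=os.getenv('HF_TOKEN'))
    result = runner.run_vid2vid_zero(
        model_path='runwayml/stable-diffusion-v1-5',
        input_video=SimpleNamespace(name='data/example.mp4'),  # hypothetical path
        prompt='a black swan swimming in a pond',
        n_sample_frames=8,
        sample_start_idx=0,
        sample_frame_rate=1,
        validation_prompt='a white duck swimming in a pond',
        guidance_scale=7.5,
        resolution='512',
        seed=33,
        remove_gpu_after_running=False,
    )
    print(result)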