Spaces:
Runtime error
Runtime error
File size: 4,211 Bytes
16c8067 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from typing import Iterable, Iterator, List, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from tqdm import tqdm

from config import hparams as hp
from nota_wav2lip.models.util import count_params, load_model
class Wav2LipInferenceImpl:
    """Batched lip-sync inference wrapper around a model from ``load_model``.

    The iterator methods pair audio chunks with ``(frame, face box)`` entries,
    group them into batches, run the model, and paste every prediction back
    into a copy of the corresponding frame.
    """

    def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
        """Load the model and cache a formatted parameter count.

        Args:
            model_name: Passed to ``load_model`` as its first argument.
            hp_inference_model: Mapping expanded into keyword arguments of
                ``load_model``.
            device: Device handed to ``load_model``; input tensors are moved
                to the same device during inference.
        """
        self.model: nn.Module = load_model(
            model_name,
            device=device,
            **hp_inference_model
        )
        self.device = device
        self._params: str = self._format_param(count_params(self.model))

    @property
    def params(self):
        """Parameter count of the model formatted in millions, e.g. ``"36.3M"``."""
        return self._params

    @staticmethod
    def _format_param(num_params: int) -> str:
        """Format ``num_params`` as millions with one decimal and an ``M`` suffix."""
        params_in_million = num_params / 1e6
        return f"{params_in_million:.1f}M"

    @staticmethod
    def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]:
        """Return four fresh empty lists: images, mels, frames, coords."""
        return [], [], [], []

    @staticmethod
    def _finalize_batch(
        img_batch: List[np.ndarray],
        mel_batch: List[np.ndarray],
        img_size: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Stack accumulated faces and mel chunks into model-ready arrays.

        Args:
            img_batch: Face crops of identical shape, each with a trailing
                channel axis (the concatenation below runs along axis 3).
            mel_batch: Two-dimensional mel chunks of identical shape.
            img_size: Side length the faces were resized to; index
                ``img_size // 2`` onward along axis 1 is zeroed in the
                masked copy.

        Returns:
            ``(images, mels)`` where ``images`` is the masked copy joined
            with the unmasked faces along the last axis and divided by 255,
            and ``mels`` carries an extra trailing axis of length 1.
        """
        faces = np.asarray(img_batch)
        mels = np.asarray(mel_batch)

        # Blank everything from the vertical midpoint downwards in the copy;
        # the untouched faces are appended after it along the channel axis.
        masked = faces.copy()
        masked[:, img_size // 2:] = 0

        images = np.concatenate((masked, faces), axis=3) / 255.
        mels = np.reshape(mels, [len(mels), mels.shape[1], mels.shape[2], 1])
        return images, mels

    @staticmethod
    def _num_batches(audio_iterable: Iterable[np.ndarray], batch_size: int) -> Optional[int]:
        """Return the expected batch count, or ``None`` if it is unknown.

        ``audio_iterable`` is only required to be iterable, so it may be a
        generator without ``__len__``. In that case ``None`` is returned,
        which lets the progress bar run without a total instead of failing.
        """
        try:
            num_chunks = len(audio_iterable)
        except TypeError:
            return None
        return int(np.ceil(float(num_chunks) / batch_size))

    def get_data_iterator(
        self,
        audio_iterable: Iterable[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[Tuple[np.ndarray, np.ndarray, List[np.ndarray], List[List[int]]]]:
        """Yield model-ready batches pairing audio chunks with video frames.

        Frames are reused cyclically: chunk ``i`` is paired with entry
        ``i % len(video_iterable)``.

        Args:
            audio_iterable: Mel chunks, one per output frame.
            video_iterable: ``(frame, coords)`` entries where ``coords`` is
                ``[y1, y2, x1, x2]`` of the face box inside ``frame``.

        Yields:
            ``(img_batch, mel_batch, frame_batch, coords_batch)``. The first
            two are arrays built by ``_finalize_batch``; the last two are
            lists holding the frame copies and their face boxes. Batches
            hold ``hp.inference.batch_size`` items, except possibly the
            last one.

        Raises:
            ZeroDivisionError: If ``video_iterable`` is empty while
                ``audio_iterable`` yields at least one chunk.
        """
        img_size = hp.face.img_size
        batch_size = hp.inference.batch_size
        num_frames = len(video_iterable)

        img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()
        for i, m in enumerate(audio_iterable):
            _frame_to_save, coords = video_iterable[i % num_frames]
            # Work on a copy so the caller's frame is never modified when the
            # prediction is later pasted back.
            frame_to_save = _frame_to_save.copy()

            face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy()
            face: np.ndarray = cv2.resize(face, (img_size, img_size))

            img_batch.append(face)
            mel_batch.append(m)
            frame_batch.append(frame_to_save)
            coords_batch.append(coords)

            if len(img_batch) >= batch_size:
                images, mels = self._finalize_batch(img_batch, mel_batch, img_size)
                yield images, mels, frame_batch, coords_batch
                img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        # Flush the trailing partial batch, if any.
        if img_batch:
            images, mels = self._finalize_batch(img_batch, mel_batch, img_size)
            yield images, mels, frame_batch, coords_batch

    @torch.no_grad()
    def inference_with_iterator(
        self,
        audio_iterable: Iterable[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[np.ndarray]:
        """Run the model over all audio chunks and yield the edited frames.

        Args:
            audio_iterable: Mel chunks, one per output frame. It may be an
                unsized iterable; the progress bar then has no total.
            video_iterable: ``(frame, coords)`` entries, see
                ``get_data_iterator``.

        Yields:
            One frame per audio chunk, with the model prediction resized to
            the face box and written into it.
        """
        data_iterator = self.get_data_iterator(audio_iterable, video_iterable)
        total = self._num_batches(audio_iterable, hp.inference.batch_size)

        for img_batch, mel_batch, frames, coords in tqdm(data_iterator, total=total):
            # Move the trailing channel axis to position 1 before building tensors.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)

            preds: torch.Tensor = self.forward(mel_batch, img_batch)
            # Back to channel-last and rescaled by 255 for pasting into frames.
            preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for pred, frame, coord in zip(preds, frames, coords):
                y1, y2, x1, x2 = coord
                pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1))
                frame[y1:y2, x1:x2] = pred
                yield frame

    @torch.no_grad()
    def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor:
        """Call the wrapped model on ``(audio_sequences, face_sequences)`` without gradients."""
        return self.model(audio_sequences, face_sequences)

    def __call__(self, *args, **kwargs):
        """Alias for ``forward``."""
        return self.forward(*args, **kwargs)
|