Spaces:
Build error
saifturzo3
committed on
Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitignore +16 -0
- Dockerfile +57 -0
- README.md +152 -6
- Wav2Lip_simplified_V5(offline).ipynb +225 -0
- Wav2Lip_simplified_v4.ipynb +482 -0
- Wav2Lip_simplified_v5.ipynb +640 -0
- app.py +44 -0
- audio.py +136 -0
- checkpoints/README.md +1 -0
- cog.yaml +35 -0
- color_syncnet_train.py +279 -0
- evaluation/README.md +63 -0
- evaluation/gen_videos_from_filelist.py +238 -0
- evaluation/real_videos_inference.py +305 -0
- evaluation/scores_LSE/SyncNetInstance_calc_scores.py +210 -0
- evaluation/scores_LSE/calculate_scores_LRS.py +53 -0
- evaluation/scores_LSE/calculate_scores_real_videos.py +45 -0
- evaluation/scores_LSE/calculate_scores_real_videos.sh +8 -0
- evaluation/test_filelists/README.md +13 -0
- evaluation/test_filelists/ReSyncED/random_pairs.txt +160 -0
- evaluation/test_filelists/ReSyncED/tts_pairs.txt +18 -0
- face_detect.py +55 -0
- face_detection/README.md +1 -0
- face_detection/__init__.py +7 -0
- face_detection/api.py +79 -0
- face_detection/detection/__init__.py +1 -0
- face_detection/detection/core.py +130 -0
- face_detection/detection/sfd/__init__.py +1 -0
- face_detection/detection/sfd/bbox.py +129 -0
- face_detection/detection/sfd/detect.py +112 -0
- face_detection/detection/sfd/net_s3fd.py +129 -0
- face_detection/detection/sfd/sfd_detector.py +60 -0
- face_detection/models.py +261 -0
- face_detection/utils.py +313 -0
- filelists/README.md +1 -0
- hparams.py +101 -0
- hq_wav2lip_train.py +443 -0
- inference.py +323 -0
- models/__init__.py +2 -0
- models/conv.py +44 -0
- models/syncnet.py +66 -0
- models/wav2lip.py +184 -0
- predict.py +144 -0
- preprocess.py +113 -0
- requirements.txt +10 -0
- requirementsCPU.txt +9 -0
- requirements_colab.txt +7 -0
- results/README.md +1 -0
- scripts/download_models.sh +7 -0
- scripts/run-dev.sh +13 -0
.gitignore
ADDED
@@ -0,0 +1,16 @@
+*.pkl
+*.jpg
+*.mp4
+*.pth
+*.pyc
+__pycache__
+*.h5
+*.avi
+*.wav
+filelists/*.txt
+evaluation/test_filelists/lr*.txt
+*.pyc
+*.mkv
+*.gif
+*.webm
+*.mp3
Dockerfile
ADDED
@@ -0,0 +1,57 @@
+FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+# install python via pyenv
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    make \
+    build-essential \
+    libssl-dev \
+    zlib1g-dev \
+    libbz2-dev \
+    libreadline-dev \
+    libsqlite3-dev \
+    wget \
+    curl \
+    llvm \
+    libncurses5-dev \
+    libncursesw5-dev \
+    xz-utils \
+    tk-dev \
+    libffi-dev \
+    liblzma-dev \
+    git \
+    ca-certificates \
+    libgl1 \
+    && rm -rf /var/lib/apt/lists/*
+ENV PATH="/root/.pyenv/shims:/root/.pyenv/bin:$PATH"
+ARG PYTHON_VERSION=3.8
+RUN curl -s -S -L https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer | bash && \
+    pyenv install $PYTHON_VERSION && \
+    pyenv global $PYTHON_VERSION
+
+# install cog
+RUN pip install cog
+
+# install deps
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    ffmpeg libsndfile1 \
+    && rm -rf /var/lib/apt/lists/*
+
+# copy to /src
+ENV WORKDIR /src
+RUN mkdir -p $WORKDIR
+WORKDIR $WORKDIR
+
+# install requirements
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+RUN pip install git+https://github.com/elliottzheng/batch-face.git@master
+
+# copy sources
+COPY . .
+
+ENV PYTHONUNBUFFERED=1
+
+# run cog
+CMD python3 -m cog.server.http
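For context, here is a minimal sketch of how this image could be built and served locally. The `wav2lip-cog` tag, the GPU flag, and the port mapping are illustrative assumptions rather than part of this commit; cog's HTTP server listens on port 5000 by default.

```bash
# Build the image from the repository root (the tag name is arbitrary).
docker build -t wav2lip-cog .

# Start the prediction server defined by `CMD python3 -m cog.server.http`.
docker run --rm --gpus all -p 5000:5000 wav2lip-cog
```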
README.md
CHANGED
@@ -1,12 +1,158 @@
 ---
 title: Wav2Lip
-
-colorFrom: yellow
-colorTo: indigo
+app_file: app.py
 sdk: gradio
 sdk_version: 4.21.0
-app_file: app.py
-pinned: false
 ---
+# **Wav2Lip**: *Accurately Lip-syncing Videos In The Wild*
+
+For commercial requests, please contact us at [email protected] or [email protected]. We have an HD model ready that can be used commercially.
+
+This code is part of the paper: _A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild_ published at ACM Multimedia 2020.
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrs2)](https://paperswithcode.com/sota/lip-sync-on-lrs2?p=a-lip-sync-expert-is-all-you-need-for-speech)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrs3)](https://paperswithcode.com/sota/lip-sync-on-lrs3?p=a-lip-sync-expert-is-all-you-need-for-speech)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrw)](https://paperswithcode.com/sota/lip-sync-on-lrw?p=a-lip-sync-expert-is-all-you-need-for-speech)
+
+|📑 Original Paper|📰 Project Page|🌀 Demo|⚡ Live Testing|📔 Colab Notebook|
+|:-:|:-:|:-:|:-:|:-:|
+|[Paper](http://arxiv.org/abs/2008.10010) | [Project Page](http://cvit.iiit.ac.in/research/projects/cvit-projects/a-lip-sync-expert-is-all-you-need-for-speech-to-lip-generation-in-the-wild/) | [Demo Video](https://youtu.be/0fXaDCZNOJc) | [Interactive Demo](https://bhaasha.iiit.ac.in/lipsync) | [Colab Notebook](https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing) / [Updated Colab Notebook](https://colab.research.google.com/drive/1IjFW1cLevs6Ouyu4Yht4mnR4yeuMqO7Y#scrollTo=MH1m608OymLH)|
+
+<img src="https://drive.google.com/uc?export=view&id=1Wn0hPmpo4GRbCIJR8Tf20Akzdi1qjjG9"/>
+
+----------
+**Highlights**
+----------
+- Weights of the visual quality disc have been updated in the readme!
+- Lip-sync videos to any target speech with high accuracy :100:. Try our [interactive demo](https://bhaasha.iiit.ac.in/lipsync).
+- :sparkles: Works for any identity, voice, and language. Also works for CGI faces and synthetic voices.
+- Complete training code, inference code, and pretrained models are available :boom:
+- Or, quick-start with the Google Colab Notebook: [Link](https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing). Checkpoints and samples are available in a Google Drive [folder](https://drive.google.com/drive/folders/1I-0dNLfFOSFwrfqjNa-SXuwaURHE5K4k?usp=sharing) as well. There is also a [tutorial video](https://www.youtube.com/watch?v=Ic0TBhfuOrA) on this, courtesy of [What Make Art](https://www.youtube.com/channel/UCmGXH-jy0o2CuhqtpxbaQgA). Also, thanks to [Eyal Gruss](https://eyalgruss.com), there is a more accessible [Google Colab notebook](https://j.mp/wav2lip) with more useful features. A tutorial Colab notebook is available at this [link](https://colab.research.google.com/drive/1IjFW1cLevs6Ouyu4Yht4mnR4yeuMqO7Y#scrollTo=MH1m608OymLH).
+- :fire: :fire: Several new, reliable evaluation benchmarks and metrics [[`evaluation/` folder of this repo]](https://github.com/Rudrabha/Wav2Lip/tree/master/evaluation) released. Instructions to calculate the metrics reported in the paper are also present.
+
+--------
+**Disclaimer**
+--------
+All results from this open-source code or our [demo website](https://bhaasha.iiit.ac.in/lipsync) should be used for research/academic/personal purposes only. As the models are trained on the <a href="http://www.robots.ox.ac.uk/~vgg/data/lip_reading/lrs2.html">LRS2 dataset</a>, any form of commercial use is strictly prohibited. For commercial requests please contact us directly!
+
+Prerequisites
+-------------
+- `Python 3.6`
+- ffmpeg: `sudo apt-get install ffmpeg`
+- Install the necessary packages using `pip install -r requirements.txt`. Alternatively, instructions for using a docker image are provided [here](https://gist.github.com/xenogenesi/e62d3d13dadbc164124c830e9c453668). Have a look at [this comment](https://github.com/Rudrabha/Wav2Lip/issues/131#issuecomment-725478562) and comment on [the gist](https://gist.github.com/xenogenesi/e62d3d13dadbc164124c830e9c453668) if you encounter any issues.
+- The face detection [pre-trained model](https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth) should be downloaded to `face_detection/detection/sfd/s3fd.pth`. Alternative [link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/prajwal_k_research_iiit_ac_in/EZsy6qWuivtDnANIG73iHjIBjMSoojcIV0NULXV-yiuiIg?e=qTasa8) if the above does not work.
+
+Getting the weights
+----------
+| Model | Description | Link to the model |
+| :-------------: | :---------------: | :---------------: |
+| Wav2Lip | Highly accurate lip-sync | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/Eb3LEzbfuKlJiR600lQWRxgBIY27JZg80f7V9jtMfbNDaQ?e=TBFBVW) |
+| Wav2Lip + GAN | Slightly inferior lip-sync, but better visual quality | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EdjI7bZlgApMqsVoEUUXpLsBxqXbn5z8VTmoxp55YNDcIA?e=n9ljGW) |
+| Expert Discriminator | Weights of the expert discriminator | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EQRvmiZg-HRAjvI6zqN9eTEBP74KefynCwPWVmF57l-AYA?e=ZRPHKP) |
+| Visual Quality Discriminator | Weights of the visual disc trained in a GAN setup | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EQVqH88dTm1HjlK11eNba5gBbn15WMS0B0EZbDBttqrqkg?e=ic0ljo) |
+
+Lip-syncing videos using the pre-trained models (Inference)
+-------
+You can lip-sync any video to any audio:
+```bash
+python inference.py --checkpoint_path <ckpt> --face <video.mp4> --audio <an-audio-source>
+```
+The result is saved (by default) in `results/result_voice.mp4`. You can specify it as an argument, similar to several other available options. The audio source can be any file supported by `FFMPEG` containing audio data: `*.wav`, `*.mp3`, or even a video file, from which the code will automatically extract the audio.
+
+##### Tips for better results:
+- Experiment with the `--pads` argument to adjust the detected face bounding box. This often leads to improved results. You might need to increase the bottom padding to include the chin region, e.g. `--pads 0 20 0 0`.
+- If you see the mouth position dislocated or some weird artifacts such as two mouths, it can be because of over-smoothing the face detections. Use the `--nosmooth` argument and give it another try.
+- Experiment with the `--resize_factor` argument to get a lower-resolution video. Why? The models are trained on faces that were at a lower resolution. You might get better, visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too).
+- The Wav2Lip model without GAN usually needs more experimenting with the above two options to get the most ideal results, and sometimes it can give you a better result as well.
+
+Preparing LRS2 for training
+----------
+Our models are trained on LRS2. See [here](#training-on-datasets-other-than-lrs2) for a few suggestions regarding training on other datasets.
+##### LRS2 dataset folder structure
+
+```
+data_root (mvlrs_v1)
+├── main, pretrain (we use only main folder in this work)
+|   ├── list of folders
+|   │   ├── five-digit numbered video IDs ending with (.mp4)
+```
+
+Place the LRS2 filelists (train, val, test) `.txt` files in the `filelists/` folder.
+
+##### Preprocess the dataset for fast training
+
+```bash
+python preprocess.py --data_root data_root/main --preprocessed_root lrs2_preprocessed/
+```
+Additional options like `batch_size` and the number of GPUs to use in parallel can also be set.
+
+##### Preprocessed LRS2 folder structure
+```
+preprocessed_root (lrs2_preprocessed)
+├── list of folders
+|   ├── Folders with five-digit numbered video IDs
+|   │   ├── *.jpg
+|   │   ├── audio.wav
+```
+
+Train!
+----------
+There are two major steps: (i) train the expert lip-sync discriminator, (ii) train the Wav2Lip model(s).
+
+##### Training the expert discriminator
+You can download [the pre-trained weights](#getting-the-weights) if you want to skip this step. To train it:
+```bash
+python color_syncnet_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints>
+```
+##### Training the Wav2Lip models
+You can either train the model without the additional visual quality discriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run:
+```bash
+python wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints> --syncnet_checkpoint_path <path_to_expert_disc_checkpoint>
+```
+
+To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both files are similar. In both cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional, less commonly used hyper-parameters at the bottom of the `hparams.py` file.
+
+Training on datasets other than LRS2
+------------------------------------
+Training on other datasets might require modifications to the code. Please read the following before you raise an issue:
+
+- You might not get good results by training/fine-tuning on a few minutes of a single speaker. This is a separate research problem, to which we do not have a solution yet. Thus, we would most likely not be able to resolve your issue.
+- You must train the expert discriminator for your own dataset before training Wav2Lip.
+- If it is your own dataset downloaded from the web, in most cases it needs to be sync-corrected.
+- Be mindful of the FPS of the videos of your dataset. Changes to FPS would need significant code changes.
+- The expert discriminator's eval loss should go down to ~0.25 and the Wav2Lip eval sync loss should go down to ~0.2 to get good results.
+
+When raising an issue on this topic, please let us know that you are aware of all these points.
+
+We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model.
+
+Evaluation
+----------
+Please check the `evaluation/` folder for the instructions.
+
+License and Citation
+----------
+This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact us directly at [email protected] or [email protected]. We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model. Please cite the following paper if you use this repository:
+```
+@inproceedings{10.1145/3394171.3413532,
+author = {Prajwal, K R and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C.V.},
+title = {A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild},
+year = {2020},
+isbn = {9781450379885},
+publisher = {Association for Computing Machinery},
+address = {New York, NY, USA},
+url = {https://doi.org/10.1145/3394171.3413532},
+doi = {10.1145/3394171.3413532},
+booktitle = {Proceedings of the 28th ACM International Conference on Multimedia},
+pages = {484–492},
+numpages = {9},
+keywords = {lip sync, talking face generation, video generation},
+location = {Seattle, WA, USA},
+series = {MM '20}
+}
+```
+
 
-
+Acknowledgements
+----------
+Parts of the code structure are inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models. We thank [zabique](https://github.com/zabique) for the tutorial Colab notebook.
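As a worked example of the inference tips in the README above, a single call that combines the padding, smoothing, and resolution options might look like the following; the file names and flag values are illustrative, not taken from this commit.

```bash
# Extra chin padding, half-resolution frames, and detection smoothing disabled.
python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth \
    --face input_vid.mp4 --audio input_audio.wav \
    --pads 0 20 0 0 --resize_factor 2 --nosmooth
```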
Wav2Lip_simplified_V5(offline).ipynb
ADDED
@@ -0,0 +1,225 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f1e90f25",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Step 1: Install dependency\n",
+    "!pip install ffmpeg-python\n",
+    "\n",
+    "# Step 2: Clone the Wav2Lip repository\n",
+    "!git clone https://github.com/justinjohn0306/Wav2Lip\n",
+    "\n",
+    "# Step 3: Download pretrained model\n",
+    "import requests\n",
+    "url = \"https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?share=EdjI7bZlgApMqsVoEUUXpLsBxqXbn5z8VTmoxp55YNDcIA\"\n",
+    "response = requests.get(url)\n",
+    "\n",
+    "with open(\"Wav2Lip/checkpoints/wav2lip_gan.pth\", \"wb\") as f:\n",
+    " f.write(response.content)\n",
+    " \n",
+    "# Step 4: Install the required dependencies for Wav2Lip\n",
+    "!cd Wav2Lip && pip install -r requirements.txt\n",
+    "!pip install pyaudio\n",
+    "\n",
+    "\n",
+    "# Step 5: Download pretrained model for face detection\n",
+    "url = \"https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth\"\n",
+    "response = requests.get(url)\n",
+    "\n",
+    "with open(\"Wav2Lip/face_detection/detection/sfd/s3fd.pth\", \"wb\") as f:\n",
+    " f.write(response.content)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8e86c988",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import subprocess\n",
+    "from urllib import parse as urlparse\n",
+    "\n",
+    "# Step 1: Install yt-dlp\n",
+    "subprocess.run(['pip', 'install', 'yt-dlp'])\n",
+    "\n",
+    "# Step 2: Define YouTube URL and Video ID\n",
+    "YOUTUBE_URL = 'https://www.youtube.com/watch?v=vAnWYLTdvfY'\n",
+    "url_data = urlparse.urlparse(YOUTUBE_URL)\n",
+    "query = urlparse.parse_qs(url_data.query)\n",
+    "YOUTUBE_ID = query[\"v\"][0]\n",
+    "\n",
+    "# Remove previous input video\n",
+    "if os.path.isfile('input_vid.mp4'):\n",
+    " os.remove('input_vid.mp4')\n",
+    "\n",
+    "# Trim video (start, end) seconds\n",
+    "start = 35\n",
+    "end = 62\n",
+    "interval = end - start\n",
+    "\n",
+    "# Step 3: Download and trim the YouTube video\n",
+    "subprocess.run(['yt-dlp', '-f', 'bestvideo[ext=mp4]', '--output', \"youtube.%(ext)s\", f'https://www.youtube.com/watch?v={YOUTUBE_ID}'])\n",
+    "\n",
+    "# Cut the video using FFmpeg\n",
+    "subprocess.run(['ffmpeg', '-y', '-i', 'youtube.mp4', '-ss', str(start), '-t', str(interval), '-async', '1', 'input_vid.mp4'])\n",
+    "\n",
+    "# Display video.\n",
+    "from IPython.display import HTML\n",
+    "from base64 import b64encode\n",
+    "\n",
+    "def show_video(path):\n",
+    " mp4 = open(path, 'rb').read()\n",
+    " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+    " return HTML(f\"\"\"<video width=600 controls><source src=\"{data_url}\"></video>\"\"\")\n",
+    "\n",
+    "# Preview the trimmed video\n",
+    "show_video('input_vid.mp4')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7da8e818",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from IPython.display import Audio\n",
+    "from IPython.core.display import display\n",
+    "\n",
+    "upload_method = 'Path' # Change this to 'Record' or 'Path'\n",
+    "\n",
+    "# Remove previous input audio\n",
+    "if os.path.isfile('input_audio.wav'):\n",
+    " os.remove('input_audio.wav')\n",
+    "\n",
+    "def display_audio():\n",
+    " display(Audio('input_audio.wav'))\n",
+    "\n",
+    "if upload_method == 'Record':\n",
+    " import pyaudio\n",
+    " import wave\n",
+    "\n",
+    " CHUNK = 1024\n",
+    " FORMAT = pyaudio.paInt16\n",
+    " CHANNELS = 1\n",
+    " RATE = 16000\n",
+    " RECORD_SECONDS = 5\n",
+    " WAVE_OUTPUT_FILENAME = \"input_audio.wav\"\n",
+    "\n",
+    " p = pyaudio.PyAudio()\n",
+    "\n",
+    " stream = p.open(format=FORMAT,\n",
+    " channels=CHANNELS,\n",
+    " rate=RATE,\n",
+    " input=True,\n",
+    " frames_per_buffer=CHUNK)\n",
+    "\n",
+    " print(\"Recording...\")\n",
+    "\n",
+    " frames = []\n",
+    "\n",
+    " for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):\n",
+    " data = stream.read(CHUNK)\n",
+    " frames.append(data)\n",
+    "\n",
+    " print(\"Finished recording.\")\n",
+    "\n",
+    " stream.stop_stream()\n",
+    " stream.close()\n",
+    " p.terminate()\n",
+    "\n",
+    " wf = wave.open(WAVE_OUTPUT_FILENAME, 'wb')\n",
+    " wf.setnchannels(CHANNELS)\n",
+    " wf.setsampwidth(p.get_sample_size(FORMAT))\n",
+    " wf.setframerate(RATE)\n",
+    " wf.writeframes(b''.join(frames))\n",
+    " wf.close()\n",
+    "\n",
+    " display_audio()\n",
+    "\n",
+    "elif upload_method == 'Path':\n",
+    " # Add the full path to your audio\n",
+    " PATH_TO_YOUR_AUDIO = 'C:/Users/justi/OneDrive/Desktop/wav2lip/Wav2Lip/input_audio.wav'\n",
+    "\n",
+    " # Load audio with specified sampling rate\n",
+    " import librosa\n",
+    " audio, sr = librosa.load(PATH_TO_YOUR_AUDIO, sr=None)\n",
+    "\n",
+    " # Save audio with specified sampling rate\n",
+    " import soundfile as sf\n",
+    " sf.write('input_audio.wav', audio, sr, format='wav')\n",
+    "\n",
+    " display_audio()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "63289945",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# Define the parameters for the Wav2Lip model\n",
+    "pad_top = 0\n",
+    "pad_bottom = 10\n",
+    "pad_left = 0\n",
+    "pad_right = 0\n",
+    "rescaleFactor = 1\n",
+    "nosmooth = False\n",
+    "\n",
+    "# Set the path to the Wav2Lip model and input files\n",
+    "checkpoint_path = \"checkpoints/wav2lip_gan.pth\"\n",
+    "input_face = \"input_vid.mp4\"\n",
+    "input_audio = \"input_audio.wav\"\n",
+    "\n",
+    "# Run the Wav2Lip model\n",
+    "!cd Wav2Lip && python inference.py --checkpoint_path {checkpoint_path} --face {input_face} --audio {input_audio} --pads {pad_top} {pad_bottom} {pad_left} {pad_right} --resize_factor {rescaleFactor} {\"--nosmooth\" if nosmooth else \"\"}\n",
+    "\n",
+    "# Preview the output video\n",
+    "print(\"Final Video Preview\")\n",
+    "print(\"Find the output video at\", 'Wav2Lip/results/result_voice.mp4')\n",
+    "show_video('Wav2Lip/results/result_voice.mp4')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3fbafa56",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
Wav2Lip_simplified_v4.ipynb
ADDED
@@ -0,0 +1,482 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  }
+ },
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "source": [
+    "**Fixes by: [justinjohn-03](https://github.com/justinjohn0306)**"
+   ],
+   "metadata": {
+    "id": "9Uyk6DCBGHuW"
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "U1xFNFU58_2j"
+   },
+   "source": [
+    "## Goal: Make anyone speak anything (LipSync)\n",
+    "\n",
+    "* Github: https://github.com/Rudrabha/Wav2Lip\n",
+    "* Paper: https://arxiv.org/abs/2008.10010\n",
+    "*Original notebook: https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing\n",
+    "\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {
+    "id": "Qgo-oaI3JU2u",
+    "cellView": "form"
+   },
+   "source": [
+    "#@title <h1>Step1: Setup Wav2Lip</h1>\n",
+    "#@markdown * Install dependency\n",
+    "#@markdown * Download pretrained model\n",
+    "!rm -rf /content/sample_data\n",
+    "!mkdir /content/sample_data\n",
+    "\n",
+    "!git clone https://github.com/zabique/Wav2Lip\n",
+    "\n",
+    "#download the pretrained model\n",
+    "!wget 'https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?share=EdjI7bZlgApMqsVoEUUXpLsBxqXbn5z8VTmoxp55YNDcIA' -O '/content/Wav2Lip/checkpoints/wav2lip_gan.pth'\n",
+    "a = !pip install https://raw.githubusercontent.com/AwaleSajil/ghc/master/ghc-1.0-py3-none-any.whl\n",
+    "\n",
+    "# !pip uninstall tensorflow tensorflow-gpu\n",
+    "!cd Wav2Lip && pip install -r requirements.txt\n",
+    "\n",
+    "#download pretrained model for face detection\n",
+    "!wget \"https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth\" -O \"/content/Wav2Lip/face_detection/detection/sfd/s3fd.pth\"\n",
+    "\n",
+    "!pip install -q youtube-dl\n",
+    "!pip install ffmpeg-python\n",
+    "!pip install librosa==0.9.1\n",
+    "\n",
+    "#this code for recording audio\n",
+    "\"\"\"\n",
+    "To write this piece of code I took inspiration/code from a lot of places.\n",
+    "It was late night, so I'm not sure how much I created or just copied o.O\n",
+    "Here are some of the possible references:\n",
+    "https://blog.addpipe.com/recording-audio-in-the-browser-using-pure-html5-and-minimal-javascript/\n",
+    "https://stackoverflow.com/a/18650249\n",
+    "https://hacks.mozilla.org/2014/06/easy-audio-capture-with-the-mediarecorder-api/\n",
+    "https://air.ghost.io/recording-to-an-audio-file-using-html5-and-js/\n",
+    "https://stackoverflow.com/a/49019356\n",
+    "\"\"\"\n",
+    "from IPython.display import HTML, Audio\n",
+    "from google.colab.output import eval_js\n",
+    "from base64 import b64decode\n",
+    "import numpy as np\n",
+    "from scipy.io.wavfile import read as wav_read\n",
+    "import io\n",
+    "import ffmpeg\n",
+    "\n",
+    "AUDIO_HTML = \"\"\"\n",
+    "<script>\n",
+    "var my_div = document.createElement(\"DIV\");\n",
+    "var my_p = document.createElement(\"P\");\n",
+    "var my_btn = document.createElement(\"BUTTON\");\n",
+    "var t = document.createTextNode(\"Press to start recording\");\n",
+    "\n",
+    "my_btn.appendChild(t);\n",
+    "//my_p.appendChild(my_btn);\n",
+    "my_div.appendChild(my_btn);\n",
+    "document.body.appendChild(my_div);\n",
+    "\n",
+    "var base64data = 0;\n",
+    "var reader;\n",
+    "var recorder, gumStream;\n",
+    "var recordButton = my_btn;\n",
+    "\n",
+    "var handleSuccess = function(stream) {\n",
+    " gumStream = stream;\n",
+    " var options = {\n",
+    " //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
+    " mimeType : 'audio/webm;codecs=opus'\n",
+    " //mimeType : 'audio/webm;codecs=pcm'\n",
+    " }; \n",
+    " //recorder = new MediaRecorder(stream, options);\n",
+    " recorder = new MediaRecorder(stream);\n",
+    " recorder.ondataavailable = function(e) { \n",
+    " var url = URL.createObjectURL(e.data);\n",
+    " var preview = document.createElement('audio');\n",
+    " preview.controls = true;\n",
+    " preview.src = url;\n",
+    " document.body.appendChild(preview);\n",
+    "\n",
+    " reader = new FileReader();\n",
+    " reader.readAsDataURL(e.data); \n",
+    " reader.onloadend = function() {\n",
+    " base64data = reader.result;\n",
+    " //console.log(\"Inside FileReader:\" + base64data);\n",
+    " }\n",
+    " };\n",
+    " recorder.start();\n",
+    " };\n",
+    "\n",
+    "recordButton.innerText = \"Recording... press to stop\";\n",
+    "\n",
+    "navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n",
+    "\n",
+    "\n",
+    "function toggleRecording() {\n",
+    " if (recorder && recorder.state == \"recording\") {\n",
+    " recorder.stop();\n",
+    " gumStream.getAudioTracks()[0].stop();\n",
+    " recordButton.innerText = \"Saving the recording... pls wait!\"\n",
+    " }\n",
+    "}\n",
+    "\n",
+    "// https://stackoverflow.com/a/951057\n",
+    "function sleep(ms) {\n",
+    " return new Promise(resolve => setTimeout(resolve, ms));\n",
+    "}\n",
+    "\n",
+    "var data = new Promise(resolve=>{\n",
+    "//recordButton.addEventListener(\"click\", toggleRecording);\n",
+    "recordButton.onclick = ()=>{\n",
+    "toggleRecording()\n",
+    "\n",
+    "sleep(2000).then(() => {\n",
+    " // wait 2000ms for the data to be available...\n",
+    " // ideally this should use something like await...\n",
+    " //console.log(\"Inside data:\" + base64data)\n",
+    " resolve(base64data.toString())\n",
+    "\n",
+    "});\n",
+    "\n",
+    "}\n",
+    "});\n",
+    " \n",
+    "</script>\n",
+    "\"\"\"\n",
+    "\n",
+    "%cd /\n",
+    "from ghc.l_ghc_cf import l_ghc_cf\n",
+    "%cd content\n",
+    "\n",
+    "def get_audio():\n",
+    " display(HTML(AUDIO_HTML))\n",
+    " data = eval_js(\"data\")\n",
+    " binary = b64decode(data.split(',')[1])\n",
+    " \n",
+    " process = (ffmpeg\n",
+    " .input('pipe:0')\n",
+    " .output('pipe:1', format='wav')\n",
+    " .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True, quiet=True, overwrite_output=True)\n",
+    " )\n",
+    " output, err = process.communicate(input=binary)\n",
+    " \n",
+    " riff_chunk_size = len(output) - 8\n",
+    " # Break up the chunk size into four bytes, held in b.\n",
+    " q = riff_chunk_size\n",
+    " b = []\n",
+    " for i in range(4):\n",
+    " q, r = divmod(q, 256)\n",
+    " b.append(r)\n",
+    "\n",
+    " # Replace bytes 4:8 in proc.stdout with the actual size of the RIFF chunk.\n",
+    " riff = output[:4] + bytes(b) + output[8:]\n",
+    "\n",
+    " sr, audio = wav_read(io.BytesIO(riff))\n",
+    "\n",
+    " return audio, sr\n",
+    "\n",
+    "\n",
+    "from IPython.display import HTML\n",
+    "from base64 import b64encode\n",
+    "def showVideo(path):\n",
+    " mp4 = open(str(path),'rb').read()\n",
+    " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+    " return HTML(\"\"\"\n",
+    " <video width=700 controls>\n",
+    " <source src=\"%s\" type=\"video/mp4\">\n",
+    " </video>\n",
+    " \"\"\" % data_url)\n",
+    "\n",
+    "from IPython.display import clear_output"
+   ],
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "SEdy6PWDXMRL"
+   },
+   "source": [
+    "# LipSync Youtube Video"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {
+    "id": "QI4kcm8QEeGZ",
+    "cellView": "form"
+   },
+   "source": [
+    "#@title STEP2: Select a Youtube Video\n",
+    "# Install yt-dlp\n",
+    "!pip install yt-dlp\n",
+    "\n",
+    "#@markdown ### Find YouTube video ID from URL\n",
+    "from urllib import parse as urlparse\n",
+    "YOUTUBE_URL = 'https://www.youtube.com/watch?v=vAnWYLTdvfY' #@param {type:\"string\"}\n",
+    "url_data = urlparse.urlparse(YOUTUBE_URL)\n",
+    "query = urlparse.parse_qs(url_data.query)\n",
+    "YOUTUBE_ID = query[\"v\"][0]\n",
+    "\n",
+    "#@markdown ### Trim the video (start, end) seconds\n",
+    "start = 35 #@param {type:\"integer\"}\n",
+    "end = 62 #@param {type:\"integer\"}\n",
+    "interval = end - start\n",
+    "\n",
+    "# Download the YouTube video using yt-dlp\n",
+    "!yt-dlp -f 'bestvideo[ext=mp4]' --output \"youtube.%(ext)s\" https://www.youtube.com/watch?v=$YOUTUBE_ID\n",
+    "\n",
+    "# Cut the video using FFmpeg\n",
+    "!ffmpeg -y -i youtube.mp4 -ss {start} -t {interval} -async 1 /content/sample_data/input_vid.mp4\n",
+    "\n",
+    "# Preview the trimmed video\n",
+    "from IPython.display import HTML\n",
+    "from base64 import b64encode\n",
+    "mp4 = open('/content/sample_data/input_vid.mp4','rb').read()\n",
+    "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+    "HTML(f\"\"\"<video width=600 controls><source src=\"{data_url}\"></video>\"\"\")\n",
+    "\n"
+   ],
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "metadata": {
+    "id": "zS_RAeh-IfZy",
+    "cellView": "form"
+   },
+   "source": [
+    "#@title STEP3: Select Audio (Record or Upload)\n",
+    "from IPython.display import Audio \n",
+    "from IPython.core.display import display\n",
+    "\n",
+    "record_or_upload = 'Upload' #@param ['Record', 'Upload']\n",
+    "\n",
+    "def displayAudio():\n",
+    " display(Audio('/content/sample_data/input_audio.wav'))\n",
+    "if record_or_upload == 'Record':\n",
+    " audio, sr = get_audio()\n",
+    " import scipy\n",
+    " scipy.io.wavfile.write('/content/sample_data/input_audio.wav', sr, audio)\n",
+    "elif record_or_upload == 'Upload':\n",
+    " from google.colab import files\n",
+    " uploaded = files.upload()\n",
+    " for fn in uploaded.keys():\n",
+    " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
+    " name=fn, length=len(uploaded[fn])))\n",
+    " \n",
+    " #concider only the first file\n",
+    " audio_file = str(list(uploaded.keys())[0])\n",
+    " \n",
+    " # Load audio with specified sampling rate\n",
+    " import librosa\n",
+    " audio, sr = librosa.load(audio_file, sr=None)\n",
+    " \n",
+    " # Save audio with specified sampling rate\n",
+    " import soundfile as sf\n",
+    " sf.write('/content/sample_data/input_audio.wav', audio, sr, format='wav')\n",
+    " \n",
+    " clear_output()\n",
+    " displayAudio()"
+   ],
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "metadata": {
+    "id": "BQPLXJ8L0gms",
+    "cellView": "form"
+   },
+   "source": [
+    "#@title STEP4: Start Crunching and Preview Output\n",
+    "#@markdown <b>Note: Only change these, if you have to</b>\n",
+    "pad_top = 0#@param {type:\"integer\"}\n",
+    "pad_bottom = 10#@param {type:\"integer\"}\n",
+    "pad_left = 0#@param {type:\"integer\"}\n",
+    "pad_right = 0#@param {type:\"integer\"}\n",
+    "rescaleFactor = 1#@param {type:\"integer\"}\n",
+    "nosmooth = False #@param {type:\"boolean\"}\n",
+    "\n",
+    "\n",
+    "if nosmooth == False:\n",
+    " !cd Wav2Lip && python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth --face \"../sample_data/input_vid.mp4\" --audio \"../sample_data/input_audio.wav\" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor\n",
+    "else:\n",
+    " !cd Wav2Lip && python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth --face \"../sample_data/input_vid.mp4\" --audio \"../sample_data/input_audio.wav\" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor --nosmooth\n",
+    "#Preview output video\n",
+    "clear_output()\n",
+    "print(\"Final Video Preview\")\n",
+    "print(\"Download this video from\", '/content/Wav2Lip/results/result_voice.mp4')\n",
+    "showVideo('/content/Wav2Lip/results/result_voice.mp4')\n"
+   ],
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "vYxpPeie1CYL"
+   },
+   "source": [
+    "# LipSync on Your Video File"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {
+    "id": "nDuM7tfZ1F0t",
+    "cellView": "form"
+   },
+   "source": [
+    "import os\n",
+    "from google.colab import files\n",
+    "from IPython.display import HTML\n",
+    "\n",
+    "def showVideo(file_path):\n",
+    " \"\"\"Function to display video in Colab\"\"\"\n",
+    " mp4 = open(file_path,'rb').read()\n",
+    " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+    " display(HTML(\"\"\"\n",
+    " <video controls width=600>\n",
+    " <source src=\"%s\" type=\"video/mp4\">\n",
+    " </video>\n",
+    " \"\"\" % data_url))\n",
+    "\n",
+    "#@markdown ### Select an uploading method\n",
+    "upload_or_path = \"Upload\" #@param [\"Upload\", \"Custom Path\"]\n",
+    "\n",
+    "if upload_or_path == \"Upload\":\n",
+    " uploaded = files.upload()\n",
+    " for filename in uploaded.keys():\n",
+    " os.rename(filename, '/content/sample_data/input_vid.mp4')\n",
+    " PATH_TO_YOUR_VIDEO = '/content/sample_data/input_vid.mp4'\n",
+    "else:\n",
+    " PATH_TO_YOUR_VIDEO = '/content/test.mp4' #@param {type:\"string\"}\n",
+    " if not os.path.isfile(PATH_TO_YOUR_VIDEO):\n",
+    " print(\"ERROR: File not found!\")\n",
+    " raise SystemExit(0)\n",
+    "\n",
+    "#@markdown ### Trim the video (start, end) seconds\n",
+    "start_time = 0 #@param {type:\"integer\"}\n",
+    "end_time = 0 #@param {type:\"integer\"}\n",
+    "\n",
+    "if start_time == 0 and end_time == 0:\n",
+    " print(\"No trimming applied\")\n",
+    "else:\n",
+    " duration = end_time - start_time\n",
+    " os.system(f\"ffmpeg -i {PATH_TO_YOUR_VIDEO} -ss {start_time} -t {duration} -async 1 /content/sample_data/trimmed_vid.mp4\")\n",
+    " PATH_TO_YOUR_VIDEO = \"/content/sample_data/input_vid.mp4\"\n",
+    " print(f\"Video trimmed from {start_time} to {end_time} seconds\")\n",
+    "\n",
+    "print(f\"PATH_TO_YOUR_VIDEO: {PATH_TO_YOUR_VIDEO}\")\n",
+    "\n",
+    "if upload_or_path == \"Upload\":\n",
+    " clear_output()\n",
+    " print(\"Input Video\")\n",
+    " showVideo(PATH_TO_YOUR_VIDEO)\n",
+    "else:\n",
+    " if os.path.isfile(PATH_TO_YOUR_VIDEO):\n",
+    " print(\"Input Video\")\n",
+    " showVideo(PATH_TO_YOUR_VIDEO)\n"
+   ],
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "metadata": {
+    "id": "XgF4794r7sWK",
+    "cellView": "form"
+   },
+   "source": [
+    "#@title STEP3: Select Audio (Record or Upload)\n",
+    "from IPython.display import Audio \n",
+    "from IPython.core.display import display\n",
+    "\n",
+    "record_or_upload = 'Upload' #@param ['Record', 'Upload']\n",
+    "\n",
+    "def displayAudio():\n",
+    " display(Audio('/content/sample_data/input_audio.wav'))\n",
+    "if record_or_upload == 'Record':\n",
+    " audio, sr = get_audio()\n",
+    " import scipy\n",
+    " scipy.io.wavfile.write('/content/sample_data/input_audio.wav', sr, audio)\n",
+    "elif record_or_upload == 'Upload':\n",
+    " from google.colab import files\n",
+    " uploaded = files.upload()\n",
+    " for fn in uploaded.keys():\n",
+    " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
+    " name=fn, length=len(uploaded[fn])))\n",
+    " \n",
+    " #concider only the first file\n",
+    " audio_file = str(list(uploaded.keys())[0])\n",
+    " \n",
+    " # Load audio with specified sampling rate\n",
+    " import librosa\n",
+    " audio, sr = librosa.load(audio_file, sr=None)\n",
+    " \n",
+    " # Save audio with specified sampling rate\n",
+    " import soundfile as sf\n",
+    " sf.write('/content/sample_data/input_audio.wav', audio, sr, format='wav')\n",
+    " \n",
+    " clear_output()\n",
+    " displayAudio()"
+   ],
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "metadata": {
+    "id": "ZgtO08V28ANf",
+    "cellView": "form"
+   },
+   "source": [
+    "#@title STEP4: Start Crunching and Preview Output\n",
+    "#@markdown <b>Note: Only change these, if you have to</b>\n",
+    "pad_top = 0#@param {type:\"integer\"}\n",
+    "pad_bottom = 10#@param {type:\"integer\"}\n",
+    "pad_left = 0#@param {type:\"integer\"}\n",
+    "pad_right = 0#@param {type:\"integer\"}\n",
+    "rescaleFactor = 1#@param {type:\"integer\"}\n",
+    "nosmooth = False #@param {type:\"boolean\"}\n",
+    "\n",
+    "if nosmooth == False:\n",
+    " !cd Wav2Lip && python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth --face \"../sample_data/input_vid.mp4\" --audio \"../sample_data/input_audio.wav\" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor\n",
+    "else:\n",
+    " !cd Wav2Lip && python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth --face \"../sample_data/input_vid.mp4\" --audio \"../sample_data/input_audio.wav\" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor --nosmooth\n",
+    "\n",
+    "#Preview output video\n",
+    "clear_output()\n",
+    "print(\"Dowload this video from\", '/content/Wav2Lip/results/result_voice.mp4')\n",
+    "showVideo('/content/Wav2Lip/results/result_voice.mp4')\n"
+   ],
+   "execution_count": null,
+   "outputs": []
+  }
+ ]
+}
Wav2Lip_simplified_v5.ipynb
ADDED
@@ -0,0 +1,640 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+  "colab": {
+   "provenance": [],
+   "private_outputs": true
+  },
+  "kernelspec": {
+   "name": "python3",
+   "display_name": "Python 3"
+  },
+  "accelerator": "GPU"
+ },
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "U1xFNFU58_2j"
+   },
+   "source": [
+    "## Goal: Make anyone speak anything (LipSync)\n",
+    "\n",
+    "* Github: https://github.com/Rudrabha/Wav2Lip\n",
+    "* Paper: https://arxiv.org/abs/2008.10010\n",
+    "*Original notebook: https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "**Modded by: [justinjohn-03](https://github.com/justinjohn0306)**\n",
+    "\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {
+    "id": "Qgo-oaI3JU2u",
+    "cellView": "form"
+   },
+   "source": [
+    "#@title <h1>Step1: Setup Wav2Lip</h1>\n",
+    "#@markdown * Install dependency\n",
+    "#@markdown * Download pretrained model\n",
+    "from IPython.display import HTML, clear_output\n",
+    "!rm -rf /content/sample_data\n",
+    "!mkdir /content/sample_data\n",
+    "\n",
+    "!git clone https://github.com/justinjohn0306/Wav2Lip\n",
+    "\n",
+    "%cd /content/Wav2Lip\n",
+    "\n",
+    "#download the pretrained model\n",
+    "!wget 'https://github.com/justinjohn0306/Wav2Lip/releases/download/models/wav2lip.pth' -O 'checkpoints/wav2lip.pth'\n",
+    "!wget 'https://github.com/justinjohn0306/Wav2Lip/releases/download/models/wav2lip_gan.pth' -O 'checkpoints/wav2lip_gan.pth'\n",
+    "!wget 'https://github.com/justinjohn0306/Wav2Lip/releases/download/models/resnet50.pth' -O 'checkpoints/resnet50.pth'\n",
+    "!wget 'https://github.com/justinjohn0306/Wav2Lip/releases/download/models/mobilenet.pth' -O 'checkpoints/mobilenet.pth'\n",
+    "a = !pip install https://raw.githubusercontent.com/AwaleSajil/ghc/master/ghc-1.0-py3-none-any.whl\n",
+    "!pip install git+https://github.com/elliottzheng/batch-face.git@master\n",
+    "\n",
+    "!pip install ffmpeg-python mediapipe==0.8.11\n",
+    "\n",
+    "#this code for recording audio\n",
+    "\"\"\"\n",
+    "To write this piece of code I took inspiration/code from a lot of places.\n",
+    "It was late night, so I'm not sure how much I created or just copied o.O\n",
+    "Here are some of the possible references:\n",
+    "https://blog.addpipe.com/recording-audio-in-the-browser-using-pure-html5-and-minimal-javascript/\n",
+    "https://stackoverflow.com/a/18650249\n",
+    "https://hacks.mozilla.org/2014/06/easy-audio-capture-with-the-mediarecorder-api/\n",
+    "https://air.ghost.io/recording-to-an-audio-file-using-html5-and-js/\n",
+    "https://stackoverflow.com/a/49019356\n",
+    "\"\"\"\n",
+    "from IPython.display import HTML, Audio\n",
+    "from google.colab.output import eval_js\n",
+    "from base64 import b64decode\n",
+    "import numpy as np\n",
+    "from scipy.io.wavfile import read as wav_read\n",
+    "import io\n",
+    "import ffmpeg\n",
+    "\n",
+    "AUDIO_HTML = \"\"\"\n",
+    "<script>\n",
+    "var my_div = document.createElement(\"DIV\");\n",
+    "var my_p = document.createElement(\"P\");\n",
+    "var my_btn = document.createElement(\"BUTTON\");\n",
+    "var t = document.createTextNode(\"Press to start recording\");\n",
+    "\n",
+    "my_btn.appendChild(t);\n",
+    "//my_p.appendChild(my_btn);\n",
+    "my_div.appendChild(my_btn);\n",
+    "document.body.appendChild(my_div);\n",
+    "\n",
+    "var base64data = 0;\n",
+    "var reader;\n",
+    "var recorder, gumStream;\n",
+    "var recordButton = my_btn;\n",
+    "\n",
+    "var handleSuccess = function(stream) {\n",
+    " gumStream = stream;\n",
+    " var options = {\n",
+    " //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
+    " mimeType : 'audio/webm;codecs=opus'\n",
+    " //mimeType : 'audio/webm;codecs=pcm'\n",
+    " };\n",
+    " //recorder = new MediaRecorder(stream, options);\n",
+    " recorder = new MediaRecorder(stream);\n",
+    " recorder.ondataavailable = function(e) {\n",
+    " var url = URL.createObjectURL(e.data);\n",
+    " var preview = document.createElement('audio');\n",
+    " preview.controls = true;\n",
+    " preview.src = url;\n",
+    " document.body.appendChild(preview);\n",
+    "\n",
+    " reader = new FileReader();\n",
+    " reader.readAsDataURL(e.data);\n",
+    " reader.onloadend = function() {\n",
+    " base64data = reader.result;\n",
+    " //console.log(\"Inside FileReader:\" + base64data);\n",
+    " }\n",
+    " };\n",
+    " recorder.start();\n",
+    " };\n",
+    "\n",
+    "recordButton.innerText = \"Recording... press to stop\";\n",
+    "\n",
+    "navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n",
+    "\n",
+    "\n",
+    "function toggleRecording() {\n",
+    " if (recorder && recorder.state == \"recording\") {\n",
+    " recorder.stop();\n",
+    " gumStream.getAudioTracks()[0].stop();\n",
+    " recordButton.innerText = \"Saving the recording... pls wait!\"\n",
+    " }\n",
+    "}\n",
+    "\n",
+    "// https://stackoverflow.com/a/951057\n",
+    "function sleep(ms) {\n",
+    " return new Promise(resolve => setTimeout(resolve, ms));\n",
+    "}\n",
+    "\n",
+    "var data = new Promise(resolve=>{\n",
+    "//recordButton.addEventListener(\"click\", toggleRecording);\n",
+    "recordButton.onclick = ()=>{\n",
+    "toggleRecording()\n",
+    "\n",
+    "sleep(2000).then(() => {\n",
+    " // wait 2000ms for the data to be available...\n",
+    " // ideally this should use something like await...\n",
+    " //console.log(\"Inside data:\" + base64data)\n",
+    " resolve(base64data.toString())\n",
+    "\n",
+    "});\n",
+    "\n",
+    "}\n",
+    "});\n",
+    "\n",
+    "</script>\n",
+    "\"\"\"\n",
+    "\n",
+    "%cd /\n",
+    "from ghc.l_ghc_cf import l_ghc_cf\n",
+    "%cd content\n",
+    "\n",
+    "def get_audio():\n",
+    " display(HTML(AUDIO_HTML))\n",
+    " data = eval_js(\"data\")\n",
+    " binary = b64decode(data.split(',')[1])\n",
+    "\n",
+    " process = (ffmpeg\n",
+    " .input('pipe:0')\n",
+    " .output('pipe:1', format='wav')\n",
+    " .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True, quiet=True, overwrite_output=True)\n",
+    " )\n",
+    " output, err = process.communicate(input=binary)\n",
+    "\n",
+    " riff_chunk_size = len(output) - 8\n",
+    " # Break up the chunk size into four bytes, held in b.\n",
+    " q = riff_chunk_size\n",
+    " b = []\n",
+    " for i in range(4):\n",
+    " q, r = divmod(q, 256)\n",
+    " b.append(r)\n",
+    "\n",
+    " # Replace bytes 4:8 in proc.stdout with the actual size of the RIFF chunk.\n",
+    " riff = output[:4] + bytes(b) + output[8:]\n",
+    "\n",
+    " sr, audio = wav_read(io.BytesIO(riff))\n",
+    "\n",
+    " return audio, sr\n",
+    "\n",
+    "\n",
+    "from IPython.display import HTML\n",
+    "from base64 import b64encode\n",
+    "def showVideo(path):\n",
+    " mp4 = open(str(path),'rb').read()\n",
+    " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+    " return HTML(\"\"\"\n",
+    " <video width=700 controls>\n",
+    " <source src=\"%s\" type=\"video/mp4\">\n",
+    " </video>\n",
+    " \"\"\" % data_url)\n",
+    "\n",
+    "from IPython.display import clear_output\n",
|
208 |
+
"\n",
|
209 |
+
"clear_output()\n",
|
210 |
+
"print(\"All set and ready!\")"
|
211 |
+
],
|
212 |
+
"execution_count": null,
|
213 |
+
"outputs": []
|
214 |
+
},
|
215 |
+
{
|
216 |
+
"cell_type": "markdown",
|
217 |
+
"metadata": {
|
218 |
+
"id": "SEdy6PWDXMRL"
|
219 |
+
},
|
220 |
+
"source": [
|
221 |
+
"# LipSync Youtube Video"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"cell_type": "code",
|
226 |
+
"metadata": {
|
227 |
+
"id": "QI4kcm8QEeGZ",
|
228 |
+
"cellView": "form"
|
229 |
+
},
|
230 |
+
"source": [
|
231 |
+
"#@title STEP2: Select a Youtube Video\n",
|
232 |
+
"# Install yt-dlp\n",
|
233 |
+
"\n",
|
234 |
+
"import os\n",
|
235 |
+
"!pip install yt-dlp\n",
|
236 |
+
"\n",
|
237 |
+
"#@markdown ## Find YouTube video ID from URL\n",
|
238 |
+
"\n",
|
239 |
+
"#@markdown ___\n",
|
240 |
+
"\n",
|
241 |
+
"#@markdown Link format:\n",
|
242 |
+
"\n",
|
243 |
+
"#@markdown ``https://youtu.be/vAnWYLTdvfY`` ❌\n",
|
244 |
+
"\n",
|
245 |
+
"#@markdown ``https://www.youtube.com/watch?v=vAnWYLTdvfY`` ✔️\n",
|
246 |
+
"\n",
|
247 |
+
"!rm -df youtube.mp4\n",
|
248 |
+
"\n",
|
249 |
+
"#@markdown ___\n",
|
250 |
+
"from urllib import parse as urlparse\n",
|
251 |
+
"YOUTUBE_URL = 'https://www.youtube.com/watch?v=vAnWYLTdvfY' #@param {type:\"string\"}\n",
|
252 |
+
"url_data = urlparse.urlparse(YOUTUBE_URL)\n",
|
253 |
+
"query = urlparse.parse_qs(url_data.query)\n",
|
254 |
+
"YOUTUBE_ID = query[\"v\"][0]\n",
|
255 |
+
"\n",
|
256 |
+
"\n",
|
257 |
+
"# remove previous input video\n",
|
258 |
+
"!rm -f /content/sample_data/input_vid.mp4\n",
|
259 |
+
"\n",
|
260 |
+
"\n",
|
261 |
+
"#@markdown ___\n",
|
262 |
+
"\n",
|
263 |
+
"#@markdown ### Trim the video (start, end) seconds\n",
|
264 |
+
"start = 35 #@param {type:\"integer\"}\n",
|
265 |
+
"end = 62 #@param {type:\"integer\"}\n",
|
266 |
+
"interval = end - start\n",
|
267 |
+
"\n",
|
268 |
+
"#@markdown <font color=\"orange\"> Note: ``the trimmed video must have face on all frames``\n",
|
269 |
+
"\n",
|
270 |
+
"# Download the YouTube video using yt-dlp\n",
|
271 |
+
"!yt-dlp -f 'bestvideo[ext=mp4]' --output \"youtube.%(ext)s\" https://www.youtube.com/watch?v=$YOUTUBE_ID\n",
|
272 |
+
"\n",
|
273 |
+
"# Cut the video using FFmpeg\n",
|
274 |
+
"!ffmpeg -y -i youtube.mp4 -ss {start} -t {interval} -async 1 /content/sample_data/input_vid.mp4\n",
|
275 |
+
"\n",
|
276 |
+
"# Preview the trimmed video\n",
|
277 |
+
"from IPython.display import HTML\n",
|
278 |
+
"from base64 import b64encode\n",
|
279 |
+
"mp4 = open('/content/sample_data/input_vid.mp4','rb').read()\n",
|
280 |
+
"data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
|
281 |
+
"HTML(f\"\"\"<video width=600 controls><source src=\"{data_url}\"></video>\"\"\")\n",
|
282 |
+
"\n"
|
283 |
+
],
|
284 |
+
"execution_count": null,
|
285 |
+
"outputs": []
|
286 |
+
},
|
287 |
+
{
|
288 |
+
"cell_type": "code",
|
289 |
+
"metadata": {
|
290 |
+
"id": "zS_RAeh-IfZy",
|
291 |
+
"cellView": "form"
|
292 |
+
},
|
293 |
+
"source": [
|
294 |
+
"#@title STEP3: Select Audio (Record, Upload from local drive or Gdrive)\n",
|
295 |
+
"import os\n",
|
296 |
+
"from IPython.display import Audio\n",
|
297 |
+
"from IPython.core.display import display\n",
|
298 |
+
"\n",
|
299 |
+
"upload_method = 'Upload' #@param ['Record', 'Upload', 'Custom Path']\n",
|
300 |
+
"\n",
|
301 |
+
"#remove previous input audio\n",
|
302 |
+
"if os.path.isfile('/content/sample_data/input_audio.wav'):\n",
|
303 |
+
" os.remove('/content/sample_data/input_audio.wav')\n",
|
304 |
+
"\n",
|
305 |
+
"def displayAudio():\n",
|
306 |
+
" display(Audio('/content/sample_data/input_audio.wav'))\n",
|
307 |
+
"\n",
|
308 |
+
"if upload_method == 'Record':\n",
|
309 |
+
" audio, sr = get_audio()\n",
|
310 |
+
" import scipy\n",
|
311 |
+
" scipy.io.wavfile.write('/content/sample_data/input_audio.wav', sr, audio)\n",
|
312 |
+
"\n",
|
313 |
+
"elif upload_method == 'Upload':\n",
|
314 |
+
" from google.colab import files\n",
|
315 |
+
" uploaded = files.upload()\n",
|
316 |
+
" for fn in uploaded.keys():\n",
|
317 |
+
" print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
|
318 |
+
" name=fn, length=len(uploaded[fn])))\n",
|
319 |
+
"\n",
|
320 |
+
" # Consider only the first file\n",
|
321 |
+
" PATH_TO_YOUR_AUDIO = str(list(uploaded.keys())[0])\n",
|
322 |
+
"\n",
|
323 |
+
" # Load audio with specified sampling rate\n",
|
324 |
+
" import librosa\n",
|
325 |
+
" audio, sr = librosa.load(PATH_TO_YOUR_AUDIO, sr=None)\n",
|
326 |
+
"\n",
|
327 |
+
" # Save audio with specified sampling rate\n",
|
328 |
+
" import soundfile as sf\n",
|
329 |
+
" sf.write('/content/sample_data/input_audio.wav', audio, sr, format='wav')\n",
|
330 |
+
"\n",
|
331 |
+
" clear_output()\n",
|
332 |
+
" displayAudio()\n",
|
333 |
+
"\n",
|
334 |
+
"elif upload_method == 'Custom Path':\n",
|
335 |
+
" from google.colab import drive\n",
|
336 |
+
" drive.mount('/content/drive')\n",
|
337 |
+
" #@markdown ``Add the full path to your audio on your Gdrive`` 👇\n",
|
338 |
+
" PATH_TO_YOUR_AUDIO = '/content/drive/MyDrive/test.wav' #@param {type:\"string\"}\n",
|
339 |
+
"\n",
|
340 |
+
" # Load audio with specified sampling rate\n",
|
341 |
+
" import librosa\n",
|
342 |
+
" audio, sr = librosa.load(PATH_TO_YOUR_AUDIO, sr=None)\n",
|
343 |
+
"\n",
|
344 |
+
" # Save audio with specified sampling rate\n",
|
345 |
+
" import soundfile as sf\n",
|
346 |
+
" sf.write('/content/sample_data/input_audio.wav', audio, sr, format='wav')\n",
|
347 |
+
"\n",
|
348 |
+
" clear_output()\n",
|
349 |
+
" displayAudio()\n"
|
350 |
+
],
|
351 |
+
"execution_count": null,
|
352 |
+
"outputs": []
|
353 |
+
},
|
354 |
+
{
|
355 |
+
"cell_type": "code",
|
356 |
+
"metadata": {
|
357 |
+
"id": "BQPLXJ8L0gms",
|
358 |
+
"cellView": "form"
|
359 |
+
},
|
360 |
+
"source": [
|
361 |
+
"#@title STEP4: Start Crunching and Preview Output\n",
|
362 |
+
"#@markdown <b>Note: Only change these, if you have to</b>\n",
|
363 |
+
"\n",
|
364 |
+
"%cd /content/Wav2Lip\n",
|
365 |
+
"\n",
|
366 |
+
"# Set up paths and variables for the output file\n",
|
367 |
+
"output_file_path = '/content/Wav2Lip/results/result_voice.mp4'\n",
|
368 |
+
"\n",
|
369 |
+
"# Delete existing output file before processing, if any\n",
|
370 |
+
"if os.path.exists(output_file_path):\n",
|
371 |
+
" os.remove(output_file_path)\n",
|
372 |
+
"\n",
|
373 |
+
"pad_top = 0#@param {type:\"integer\"}\n",
|
374 |
+
"pad_bottom = 10#@param {type:\"integer\"}\n",
|
375 |
+
"pad_left = 0#@param {type:\"integer\"}\n",
|
376 |
+
"pad_right = 0#@param {type:\"integer\"}\n",
|
377 |
+
"rescaleFactor = 1#@param {type:\"integer\"}\n",
|
378 |
+
"nosmooth = True #@param {type:\"boolean\"}\n",
|
379 |
+
"#@markdown ___\n",
|
380 |
+
"#@markdown Model selection:\n",
|
381 |
+
"use_hd_model = False #@param {type:\"boolean\"}\n",
|
382 |
+
"checkpoint_path = 'checkpoints/wav2lip.pth' if not use_hd_model else 'checkpoints/wav2lip_gan.pth'\n",
|
383 |
+
"\n",
|
384 |
+
"\n",
|
385 |
+
"if nosmooth == False:\n",
|
386 |
+
" !python inference.py --checkpoint_path $checkpoint_path --face \"../sample_data/input_vid.mp4\" --audio \"../sample_data/input_audio.wav\" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor\n",
|
387 |
+
"else:\n",
|
388 |
+
" !python inference.py --checkpoint_path $checkpoint_path --face \"../sample_data/input_vid.mp4\" --audio \"../sample_data/input_audio.wav\" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor --nosmooth\n",
|
389 |
+
"\n",
|
390 |
+
"#Preview output video\n",
|
391 |
+
"if os.path.exists(output_file_path):\n",
|
392 |
+
" clear_output()\n",
|
393 |
+
" print(\"Final Video Preview\")\n",
|
394 |
+
" print(\"Download this video from\", output_file_path)\n",
|
395 |
+
" showVideo(output_file_path)\n",
|
396 |
+
"else:\n",
|
397 |
+
" print(\"Processing failed. Output video not found.\")"
|
398 |
+
],
|
399 |
+
"execution_count": null,
|
400 |
+
"outputs": []
|
401 |
+
},
|
402 |
+
{
|
403 |
+
"cell_type": "markdown",
|
404 |
+
"metadata": {
|
405 |
+
"id": "vYxpPeie1CYL"
|
406 |
+
},
|
407 |
+
"source": [
|
408 |
+
"# LipSync on Your Video File"
|
409 |
+
]
|
410 |
+
},
|
411 |
+
{
|
412 |
+
"cell_type": "code",
|
413 |
+
"metadata": {
|
414 |
+
"id": "nDuM7tfZ1F0t",
|
415 |
+
"cellView": "form"
|
416 |
+
},
|
417 |
+
"source": [
|
418 |
+
"import os\n",
|
419 |
+
"import shutil\n",
|
420 |
+
"from google.colab import drive\n",
|
421 |
+
"from google.colab import files\n",
|
422 |
+
"from IPython.display import HTML, clear_output\n",
|
423 |
+
"from base64 import b64encode\n",
|
424 |
+
"import moviepy.editor as mp\n",
|
425 |
+
"\n",
|
426 |
+
"\n",
|
427 |
+
"def showVideo(file_path):\n",
|
428 |
+
" \"\"\"Function to display video in Colab\"\"\"\n",
|
429 |
+
" mp4 = open(file_path,'rb').read()\n",
|
430 |
+
" data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
|
431 |
+
" display(HTML(\"\"\"\n",
|
432 |
+
" <video controls width=600>\n",
|
433 |
+
" <source src=\"%s\" type=\"video/mp4\">\n",
|
434 |
+
" </video>\n",
|
435 |
+
" \"\"\" % data_url))\n",
|
436 |
+
"\n",
|
437 |
+
"def get_video_resolution(video_path):\n",
|
438 |
+
" \"\"\"Function to get the resolution of a video\"\"\"\n",
|
439 |
+
" import cv2\n",
|
440 |
+
" video = cv2.VideoCapture(video_path)\n",
|
441 |
+
" width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
|
442 |
+
" height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
|
443 |
+
" return (width, height)\n",
|
444 |
+
"\n",
|
445 |
+
"def resize_video(video_path, new_resolution):\n",
|
446 |
+
" \"\"\"Function to resize a video\"\"\"\n",
|
447 |
+
" import cv2\n",
|
448 |
+
" video = cv2.VideoCapture(video_path)\n",
|
449 |
+
" fourcc = int(video.get(cv2.CAP_PROP_FOURCC))\n",
|
450 |
+
" fps = video.get(cv2.CAP_PROP_FPS)\n",
|
451 |
+
" width, height = new_resolution\n",
|
452 |
+
" output_path = os.path.splitext(video_path)[0] + '_720p.mp4'\n",
|
453 |
+
" writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))\n",
|
454 |
+
" while True:\n",
|
455 |
+
" success, frame = video.read()\n",
|
456 |
+
" if not success:\n",
|
457 |
+
" break\n",
|
458 |
+
" resized_frame = cv2.resize(frame, new_resolution)\n",
|
459 |
+
" writer.write(resized_frame)\n",
|
460 |
+
" video.release()\n",
|
461 |
+
" writer.release()\n",
|
462 |
+
"\n",
|
463 |
+
"# Mount Google Drive if it's not already mounted\n",
|
464 |
+
"if not os.path.isdir(\"/content/drive/MyDrive\"):\n",
|
465 |
+
" drive.mount('/content/drive', force_remount=True)\n",
|
466 |
+
"\n",
|
467 |
+
"#@markdown ### Select an uploading method\n",
|
468 |
+
"upload_method = \"Upload\" #@param [\"Upload\", \"Custom Path\"]\n",
|
469 |
+
"\n",
|
470 |
+
"\n",
|
471 |
+
"# remove previous input video\n",
|
472 |
+
"if os.path.isfile('/content/sample_data/input_vid.mp4'):\n",
|
473 |
+
" os.remove('/content/sample_data/input_vid.mp4')\n",
|
474 |
+
"\n",
|
475 |
+
"if upload_method == \"Upload\":\n",
|
476 |
+
" uploaded = files.upload()\n",
|
477 |
+
" for filename in uploaded.keys():\n",
|
478 |
+
" os.rename(filename, '/content/sample_data/input_vid.mp4')\n",
|
479 |
+
" PATH_TO_YOUR_VIDEO = '/content/sample_data/input_vid.mp4'\n",
|
480 |
+
"\n",
|
481 |
+
"elif upload_method == 'Custom Path':\n",
|
482 |
+
" #@markdown ``Add the full path to your video on your Gdrive `` 👇\n",
|
483 |
+
" PATH_TO_YOUR_VIDEO = '/content/drive/MyDrive/test.mp4' #@param {type:\"string\"}\n",
|
484 |
+
" if not os.path.isfile(PATH_TO_YOUR_VIDEO):\n",
|
485 |
+
" print(\"ERROR: File not found!\")\n",
|
486 |
+
" raise SystemExit(0)\n",
|
487 |
+
"\n",
|
488 |
+
"#@markdown <font color=\"orange\">Notes:\n",
|
489 |
+
"\n",
|
490 |
+
"#@markdown <font color=\"orange\">. ``If your uploaded video is 1080p or higher resolution, this cell will resize it to 720p.``\n",
|
491 |
+
"\n",
|
492 |
+
"#@markdown <font color=\"orange\">. ``Do not upload videos longer than 60 seconds.``\n",
|
493 |
+
"\n",
|
494 |
+
"#@markdown ___\n",
|
495 |
+
"\n",
|
496 |
+
"video_duration = mp.VideoFileClip(PATH_TO_YOUR_VIDEO).duration\n",
|
497 |
+
"if video_duration > 60:\n",
|
498 |
+
" print(\"WARNING: Video duration exceeds 60 seconds. Please upload a shorter video.\")\n",
|
499 |
+
" raise SystemExit(0)\n",
|
500 |
+
"\n",
|
501 |
+
"video_resolution = get_video_resolution(PATH_TO_YOUR_VIDEO)\n",
|
502 |
+
"print(f\"Video resolution: {video_resolution}\")\n",
|
503 |
+
"if video_resolution[0] >= 1920 or video_resolution[1] >= 1080:\n",
|
504 |
+
" print(\"Resizing video to 720p...\")\n",
|
505 |
+
" os.system(f\"ffmpeg -i {PATH_TO_YOUR_VIDEO} -vf scale=1280:720 /content/sample_data/input_vid.mp4\")\n",
|
506 |
+
" PATH_TO_YOUR_VIDEO = \"/content/sample_data/input_vid.mp4\"\n",
|
507 |
+
" print(\"Video resized to 720p\")\n",
|
508 |
+
"else:\n",
|
509 |
+
" print(\"No resizing needed\")\n",
|
510 |
+
"\n",
|
511 |
+
"if upload_method == \"Upload\":\n",
|
512 |
+
" clear_output()\n",
|
513 |
+
" print(\"Input Video\")\n",
|
514 |
+
" showVideo(PATH_TO_YOUR_VIDEO)\n",
|
515 |
+
"else:\n",
|
516 |
+
" if os.path.isfile(PATH_TO_YOUR_VIDEO):\n",
|
517 |
+
" shutil.copyfile(PATH_TO_YOUR_VIDEO, \"/content/sample_data/input_vid.mp4\")\n",
|
518 |
+
" print(\"Input Video\")\n",
|
519 |
+
" showVideo(PATH_TO_YOUR_VIDEO)\n"
|
520 |
+
],
|
521 |
+
"execution_count": null,
|
522 |
+
"outputs": []
|
523 |
+
},
|
524 |
+
{
|
525 |
+
"cell_type": "code",
|
526 |
+
"metadata": {
|
527 |
+
"id": "XgF4794r7sWK",
|
528 |
+
"cellView": "form"
|
529 |
+
},
|
530 |
+
"source": [
|
531 |
+
"#@title STEP3: Select Audio (Record, Upload from local drive or Gdrive)\n",
|
532 |
+
"import os\n",
|
533 |
+
"from IPython.display import Audio\n",
|
534 |
+
"from IPython.core.display import display\n",
|
535 |
+
"\n",
|
536 |
+
"upload_method = 'Upload' #@param ['Record', 'Upload', 'Custom Path']\n",
|
537 |
+
"\n",
|
538 |
+
"#remove previous input audio\n",
|
539 |
+
"if os.path.isfile('/content/sample_data/input_audio.wav'):\n",
|
540 |
+
" os.remove('/content/sample_data/input_audio.wav')\n",
|
541 |
+
"\n",
|
542 |
+
"def displayAudio():\n",
|
543 |
+
" display(Audio('/content/sample_data/input_audio.wav'))\n",
|
544 |
+
"\n",
|
545 |
+
"if upload_method == 'Record':\n",
|
546 |
+
" audio, sr = get_audio()\n",
|
547 |
+
" import scipy\n",
|
548 |
+
" scipy.io.wavfile.write('/content/sample_data/input_audio.wav', sr, audio)\n",
|
549 |
+
"\n",
|
550 |
+
"elif upload_method == 'Upload':\n",
|
551 |
+
" from google.colab import files\n",
|
552 |
+
" uploaded = files.upload()\n",
|
553 |
+
" for fn in uploaded.keys():\n",
|
554 |
+
" print('User uploaded file \"{name}\" with length {length} bytes.'.format(\n",
|
555 |
+
" name=fn, length=len(uploaded[fn])))\n",
|
556 |
+
"\n",
|
557 |
+
" # Consider only the first file\n",
|
558 |
+
" PATH_TO_YOUR_AUDIO = str(list(uploaded.keys())[0])\n",
|
559 |
+
"\n",
|
560 |
+
" # Load audio with specified sampling rate\n",
|
561 |
+
" import librosa\n",
|
562 |
+
" audio, sr = librosa.load(PATH_TO_YOUR_AUDIO, sr=None)\n",
|
563 |
+
"\n",
|
564 |
+
" # Save audio with specified sampling rate\n",
|
565 |
+
" import soundfile as sf\n",
|
566 |
+
" sf.write('/content/sample_data/input_audio.wav', audio, sr, format='wav')\n",
|
567 |
+
"\n",
|
568 |
+
" clear_output()\n",
|
569 |
+
" displayAudio()\n",
|
570 |
+
"\n",
|
571 |
+
"else: # Custom Path\n",
|
572 |
+
" from google.colab import drive\n",
|
573 |
+
" drive.mount('/content/drive')\n",
|
574 |
+
" #@markdown ``Add the full path to your audio on your Gdrive`` 👇\n",
|
575 |
+
" PATH_TO_YOUR_AUDIO = '/content/drive/MyDrive/test.wav' #@param {type:\"string\"}\n",
|
576 |
+
"\n",
|
577 |
+
" # Load audio with specified sampling rate\n",
|
578 |
+
" import librosa\n",
|
579 |
+
" audio, sr = librosa.load(PATH_TO_YOUR_AUDIO, sr=None)\n",
|
580 |
+
"\n",
|
581 |
+
" # Save audio with specified sampling rate\n",
|
582 |
+
" import soundfile as sf\n",
|
583 |
+
" sf.write('/content/sample_data/input_audio.wav', audio, sr, format='wav')\n",
|
584 |
+
"\n",
|
585 |
+
" clear_output()\n",
|
586 |
+
" displayAudio()\n"
|
587 |
+
],
|
588 |
+
"execution_count": null,
|
589 |
+
"outputs": []
|
590 |
+
},
|
591 |
+
{
|
592 |
+
"cell_type": "code",
|
593 |
+
"metadata": {
|
594 |
+
"id": "ZgtO08V28ANf",
|
595 |
+
"cellView": "form"
|
596 |
+
},
|
597 |
+
"source": [
|
598 |
+
"#@title STEP4: Start Crunching and Preview Output\n",
|
599 |
+
"#@markdown <b>Note: Only change these, if you have to</b>\n",
|
600 |
+
"\n",
|
601 |
+
"%cd /content/Wav2Lip\n",
|
602 |
+
"\n",
|
603 |
+
"# Set up paths and variables for the output file\n",
|
604 |
+
"output_file_path = '/content/Wav2Lip/results/result_voice.mp4'\n",
|
605 |
+
"\n",
|
606 |
+
"# Delete existing output file before processing, if any\n",
|
607 |
+
"if os.path.exists(output_file_path):\n",
|
608 |
+
" os.remove(output_file_path)\n",
|
609 |
+
"\n",
|
610 |
+
"pad_top = 0#@param {type:\"integer\"}\n",
|
611 |
+
"pad_bottom = 10#@param {type:\"integer\"}\n",
|
612 |
+
"pad_left = 0#@param {type:\"integer\"}\n",
|
613 |
+
"pad_right = 0#@param {type:\"integer\"}\n",
|
614 |
+
"rescaleFactor = 1#@param {type:\"integer\"}\n",
|
615 |
+
"nosmooth = True #@param {type:\"boolean\"}\n",
|
616 |
+
"#@markdown ___\n",
|
617 |
+
"#@markdown Model selection:\n",
|
618 |
+
"use_hd_model = False #@param {type:\"boolean\"}\n",
|
619 |
+
"checkpoint_path = 'checkpoints/wav2lip.pth' if not use_hd_model else 'checkpoints/wav2lip_gan.pth'\n",
|
620 |
+
"\n",
|
621 |
+
"\n",
|
622 |
+
"if nosmooth == False:\n",
|
623 |
+
" !python inference.py --checkpoint_path $checkpoint_path --face \"../sample_data/input_vid.mp4\" --audio \"../sample_data/input_audio.wav\" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor\n",
|
624 |
+
"else:\n",
|
625 |
+
" !python inference.py --checkpoint_path $checkpoint_path --face \"../sample_data/input_vid.mp4\" --audio \"../sample_data/input_audio.wav\" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor --nosmooth\n",
|
626 |
+
"\n",
|
627 |
+
"#Preview output video\n",
|
628 |
+
"if os.path.exists(output_file_path):\n",
|
629 |
+
" clear_output()\n",
|
630 |
+
" print(\"Final Video Preview\")\n",
|
631 |
+
" print(\"Download this video from\", output_file_path)\n",
|
632 |
+
" showVideo(output_file_path)\n",
|
633 |
+
"else:\n",
|
634 |
+
" print(\"Processing failed. Output video not found.\")"
|
635 |
+
],
|
636 |
+
"execution_count": null,
|
637 |
+
"outputs": []
|
638 |
+
}
|
639 |
+
]
|
640 |
+
}
|
app.py
ADDED
@@ -0,0 +1,44 @@
1 |
+
import gradio as gr
|
2 |
+
import subprocess
|
3 |
+
from subprocess import call
|
4 |
+
with gr.Blocks() as ui:
|
5 |
+
with gr.Row():
|
6 |
+
video = gr.File(label="Video or Image", info="Filepath of video/image that contains faces to use")
|
7 |
+
audio = gr.File(label="Audio", info="Filepath of video/audio file to use as raw audio source")
|
8 |
+
with gr.Column():
|
9 |
+
checkpoint = gr.Radio(["wav2lip", "wav2lip_gan"], label="Checkpoint", info="Name of saved checkpoint to load weights from")
|
10 |
+
no_smooth = gr.Checkbox(label="No Smooth", info="Prevent smoothing face detections over a short temporal window")
|
11 |
+
resize_factor = gr.Slider(minimum=1, maximum=4, step=1, label="Resize Factor", info="Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p")
|
12 |
+
with gr.Row():
|
13 |
+
with gr.Column():
|
14 |
+
pad_top = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Top", info="Padding above")
|
15 |
+
pad_bottom = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Pad Bottom (Often increasing this to 20 allows chin to be included)", info="Padding below lips")
|
16 |
+
pad_left = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Left", info="Padding to the left of lips")
|
17 |
+
pad_right = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Right", info="Padding to the right of lips")
|
18 |
+
generate_btn = gr.Button("Generate")
|
19 |
+
with gr.Column():
|
20 |
+
result = gr.Video()
|
21 |
+
|
22 |
+
def generate(video, audio, checkpoint, no_smooth, resize_factor, pad_top, pad_bottom, pad_left, pad_right):
|
23 |
+
if video is None or audio is None or checkpoint is None:
|
24 |
+
return
|
25 |
+
|
26 |
+
smooth = "--nosmooth" if no_smooth else ""
|
27 |
+
|
28 |
+
# if nosmooth == False:
|
29 |
+
# !python inference.py --checkpoint_path $checkpoint_path --face "../sample_data/input_vid.mp4" --audio "../sample_data/input_audio.wav" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor
|
30 |
+
# else:
|
31 |
+
# !python inference.py --checkpoint_path $checkpoint_path --face "../sample_data/input_vid.mp4" --audio "../sample_data/input_audio.wav" --pads $pad_top $pad_bottom $pad_left $pad_right --resize_factor $rescaleFactor --nosmooth
|
32 |
+
|
33 |
+
cmd = f"python inference.py --checkpoint_path {checkpoint} --face {video} --audio {audio} --pads {pad_top} {pad_bottom} {pad_left} {pad_right} --resize_factor {resize_factor} {smooth}"
|
34 |
+
print(cmd)
|
35 |
+
|
36 |
+
    call(cmd, shell=True)  # cmd is a single shell string, so shell=True is needed
|
37 |
+
return "results/output.mp4"
|
38 |
+
|
39 |
+
generate_btn.click(
|
40 |
+
generate,
|
41 |
+
    [video, audio, checkpoint, no_smooth, resize_factor, pad_top, pad_bottom, pad_left, pad_right],  # order must match the generate() signature
|
42 |
+
result)
|
43 |
+
|
44 |
+
ui.queue().launch(debug=True)
|
audio.py
ADDED
@@ -0,0 +1,136 @@
1 |
+
import librosa
|
2 |
+
import librosa.filters
|
3 |
+
import numpy as np
|
4 |
+
# import tensorflow as tf
|
5 |
+
from scipy import signal
|
6 |
+
from scipy.io import wavfile
|
7 |
+
from hparams import hparams as hp
|
8 |
+
|
9 |
+
def load_wav(path, sr):
|
10 |
+
return librosa.core.load(path, sr=sr)[0]
|
11 |
+
|
12 |
+
def save_wav(wav, path, sr):
|
13 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
14 |
+
#proposed by @dsmiller
|
15 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
16 |
+
|
17 |
+
def save_wavenet_wav(wav, path, sr):
|
18 |
+
librosa.output.write_wav(path, wav, sr=sr)
|
19 |
+
|
20 |
+
def preemphasis(wav, k, preemphasize=True):
|
21 |
+
if preemphasize:
|
22 |
+
return signal.lfilter([1, -k], [1], wav)
|
23 |
+
return wav
|
24 |
+
|
25 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
26 |
+
if inv_preemphasize:
|
27 |
+
return signal.lfilter([1], [1, -k], wav)
|
28 |
+
return wav
|
29 |
+
|
30 |
+
def get_hop_size():
|
31 |
+
hop_size = hp.hop_size
|
32 |
+
if hop_size is None:
|
33 |
+
assert hp.frame_shift_ms is not None
|
34 |
+
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
|
35 |
+
return hop_size
|
36 |
+
|
37 |
+
def linearspectrogram(wav):
|
38 |
+
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
39 |
+
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
|
40 |
+
|
41 |
+
if hp.signal_normalization:
|
42 |
+
return _normalize(S)
|
43 |
+
return S
|
44 |
+
|
45 |
+
def melspectrogram(wav):
|
46 |
+
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
47 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
|
48 |
+
|
49 |
+
if hp.signal_normalization:
|
50 |
+
return _normalize(S)
|
51 |
+
return S
|
52 |
+
|
53 |
+
def _lws_processor():
|
54 |
+
import lws
|
55 |
+
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
|
56 |
+
|
57 |
+
def _stft(y):
|
58 |
+
if hp.use_lws:
|
59 |
+
        return _lws_processor().stft(y).T  # _lws_processor() takes no arguments
|
60 |
+
else:
|
61 |
+
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
|
62 |
+
|
63 |
+
##########################################################
|
64 |
+
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
65 |
+
def num_frames(length, fsize, fshift):
|
66 |
+
"""Compute number of time frames of spectrogram
|
67 |
+
"""
|
68 |
+
pad = (fsize - fshift)
|
69 |
+
if length % fshift == 0:
|
70 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
71 |
+
else:
|
72 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
73 |
+
return M
|
74 |
+
|
75 |
+
|
76 |
+
def pad_lr(x, fsize, fshift):
|
77 |
+
"""Compute left and right padding
|
78 |
+
"""
|
79 |
+
M = num_frames(len(x), fsize, fshift)
|
80 |
+
pad = (fsize - fshift)
|
81 |
+
T = len(x) + 2 * pad
|
82 |
+
r = (M - 1) * fshift + fsize - T
|
83 |
+
return pad, pad + r
|
84 |
+
##########################################################
|
85 |
+
#Librosa correct padding
|
86 |
+
def librosa_pad_lr(x, fsize, fshift):
|
87 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
88 |
+
|
89 |
+
# Conversions
|
90 |
+
_mel_basis = None
|
91 |
+
|
92 |
+
def _linear_to_mel(spectogram):
|
93 |
+
global _mel_basis
|
94 |
+
if _mel_basis is None:
|
95 |
+
_mel_basis = _build_mel_basis()
|
96 |
+
return np.dot(_mel_basis, spectogram)
|
97 |
+
|
98 |
+
def _build_mel_basis():
|
99 |
+
assert hp.fmax <= hp.sample_rate // 2
|
100 |
+
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
|
101 |
+
fmin=hp.fmin, fmax=hp.fmax)
|
102 |
+
|
103 |
+
def _amp_to_db(x):
|
104 |
+
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
|
105 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
106 |
+
|
107 |
+
def _db_to_amp(x):
|
108 |
+
return np.power(10.0, (x) * 0.05)
|
109 |
+
|
110 |
+
def _normalize(S):
|
111 |
+
if hp.allow_clipping_in_normalization:
|
112 |
+
if hp.symmetric_mels:
|
113 |
+
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
|
114 |
+
-hp.max_abs_value, hp.max_abs_value)
|
115 |
+
else:
|
116 |
+
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
|
117 |
+
|
118 |
+
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
|
119 |
+
if hp.symmetric_mels:
|
120 |
+
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
|
121 |
+
else:
|
122 |
+
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
|
123 |
+
|
124 |
+
def _denormalize(D):
|
125 |
+
if hp.allow_clipping_in_normalization:
|
126 |
+
if hp.symmetric_mels:
|
127 |
+
return (((np.clip(D, -hp.max_abs_value,
|
128 |
+
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
|
129 |
+
+ hp.min_level_db)
|
130 |
+
else:
|
131 |
+
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
132 |
+
|
133 |
+
if hp.symmetric_mels:
|
134 |
+
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
|
135 |
+
else:
|
136 |
+
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
checkpoints/README.md
ADDED
@@ -0,0 +1 @@
1 |
+
Place all your checkpoints (.pth files) here.
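The Colab notebooks in this repo fetch the released weights with `wget`; a minimal sketch of doing the same from the repository root (these are the same release URLs the notebook uses — adjust if you use other weights):

```
wget 'https://github.com/justinjohn0306/Wav2Lip/releases/download/models/wav2lip.pth' -O 'checkpoints/wav2lip.pth'
wget 'https://github.com/justinjohn0306/Wav2Lip/releases/download/models/wav2lip_gan.pth' -O 'checkpoints/wav2lip_gan.pth'
```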
|
cog.yaml
ADDED
@@ -0,0 +1,35 @@
1 |
+
# Configuration for Cog ⚙️
|
2 |
+
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
|
3 |
+
|
4 |
+
image: r8.im/devxpy/cog-wav2lip
|
5 |
+
|
6 |
+
build:
|
7 |
+
# set to true if your model requires a GPU
|
8 |
+
gpu: true
|
9 |
+
cuda: "11.6.2"
|
10 |
+
|
11 |
+
# a list of ubuntu apt packages to install
|
12 |
+
system_packages:
|
13 |
+
- ffmpeg
|
14 |
+
- cmake
|
15 |
+
|
16 |
+
# python version in the form '3.8' or '3.8.12'
|
17 |
+
python_version: "3.8"
|
18 |
+
|
19 |
+
# a list of packages in the format <package-name>==<version>
|
20 |
+
python_packages:
|
21 |
+
- numpy==1.23.4
|
22 |
+
- librosa==0.7.0
|
23 |
+
- opencv-python==4.6.0.66
|
24 |
+
- torch==1.12.1+cu116 --extra-index-url=https://download.pytorch.org/whl/cu116
|
25 |
+
- torchvision==0.13.1+cu116 --extra-index-url=https://download.pytorch.org/whl/cu116
|
26 |
+
- tqdm==4.45.0
|
27 |
+
- numba==0.48
|
28 |
+
- mediapipe==0.8.11
|
29 |
+
|
30 |
+
# commands run after the environment is setup
|
31 |
+
run:
|
32 |
+
- pip install git+https://github.com/elliottzheng/batch-face.git@master
|
33 |
+
|
34 |
+
# predict.py defines how predictions are run on your model
|
35 |
+
predict: "predict.py:Predictor"
|
color_syncnet_train.py
ADDED
@@ -0,0 +1,279 @@
1 |
+
from os.path import dirname, join, basename, isfile
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
from models import SyncNet_color as SyncNet
|
5 |
+
import audio
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch import optim
|
10 |
+
import torch.backends.cudnn as cudnn
|
11 |
+
from torch.utils import data as data_utils
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from glob import glob
|
15 |
+
|
16 |
+
import os, random, cv2, argparse
|
17 |
+
from hparams import hparams, get_image_list
|
18 |
+
|
19 |
+
parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')
|
20 |
+
|
21 |
+
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)
|
22 |
+
|
23 |
+
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
|
24 |
+
parser.add_argument('--checkpoint_path', help='Resume training from this checkpoint', default=None, type=str)
|
25 |
+
|
26 |
+
args = parser.parse_args()
|
27 |
+
|
28 |
+
|
29 |
+
global_step = 0
|
30 |
+
global_epoch = 0
|
31 |
+
use_cuda = torch.cuda.is_available()
|
32 |
+
print('use_cuda: {}'.format(use_cuda))
|
33 |
+
|
34 |
+
syncnet_T = 5
|
35 |
+
syncnet_mel_step_size = 16
|
36 |
+
|
37 |
+
class Dataset(object):
|
38 |
+
def __init__(self, split):
|
39 |
+
self.all_videos = get_image_list(args.data_root, split)
|
40 |
+
|
41 |
+
def get_frame_id(self, frame):
|
42 |
+
return int(basename(frame).split('.')[0])
|
43 |
+
|
44 |
+
def get_window(self, start_frame):
|
45 |
+
start_id = self.get_frame_id(start_frame)
|
46 |
+
vidname = dirname(start_frame)
|
47 |
+
|
48 |
+
window_fnames = []
|
49 |
+
for frame_id in range(start_id, start_id + syncnet_T):
|
50 |
+
frame = join(vidname, '{}.jpg'.format(frame_id))
|
51 |
+
if not isfile(frame):
|
52 |
+
return None
|
53 |
+
window_fnames.append(frame)
|
54 |
+
return window_fnames
|
55 |
+
|
56 |
+
def crop_audio_window(self, spec, start_frame):
|
57 |
+
# num_frames = (T x hop_size * fps) / sample_rate
|
58 |
+
start_frame_num = self.get_frame_id(start_frame)
|
59 |
+
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
|
60 |
+
|
61 |
+
end_idx = start_idx + syncnet_mel_step_size
|
62 |
+
|
63 |
+
return spec[start_idx : end_idx, :]
|
64 |
+
|
65 |
+
|
66 |
+
def __len__(self):
|
67 |
+
return len(self.all_videos)
|
68 |
+
|
69 |
+
def __getitem__(self, idx):
|
70 |
+
while 1:
|
71 |
+
idx = random.randint(0, len(self.all_videos) - 1)
|
72 |
+
vidname = self.all_videos[idx]
|
73 |
+
|
74 |
+
img_names = list(glob(join(vidname, '*.jpg')))
|
75 |
+
if len(img_names) <= 3 * syncnet_T:
|
76 |
+
continue
|
77 |
+
img_name = random.choice(img_names)
|
78 |
+
wrong_img_name = random.choice(img_names)
|
79 |
+
while wrong_img_name == img_name:
|
80 |
+
wrong_img_name = random.choice(img_names)
|
81 |
+
|
82 |
+
if random.choice([True, False]):
|
83 |
+
y = torch.ones(1).float()
|
84 |
+
chosen = img_name
|
85 |
+
else:
|
86 |
+
y = torch.zeros(1).float()
|
87 |
+
chosen = wrong_img_name
|
88 |
+
|
89 |
+
window_fnames = self.get_window(chosen)
|
90 |
+
if window_fnames is None:
|
91 |
+
continue
|
92 |
+
|
93 |
+
window = []
|
94 |
+
all_read = True
|
95 |
+
for fname in window_fnames:
|
96 |
+
img = cv2.imread(fname)
|
97 |
+
if img is None:
|
98 |
+
all_read = False
|
99 |
+
break
|
100 |
+
try:
|
101 |
+
img = cv2.resize(img, (hparams.img_size, hparams.img_size))
|
102 |
+
except Exception as e:
|
103 |
+
all_read = False
|
104 |
+
break
|
105 |
+
|
106 |
+
window.append(img)
|
107 |
+
|
108 |
+
if not all_read: continue
|
109 |
+
|
110 |
+
try:
|
111 |
+
wavpath = join(vidname, "audio.wav")
|
112 |
+
wav = audio.load_wav(wavpath, hparams.sample_rate)
|
113 |
+
|
114 |
+
orig_mel = audio.melspectrogram(wav).T
|
115 |
+
except Exception as e:
|
116 |
+
continue
|
117 |
+
|
118 |
+
mel = self.crop_audio_window(orig_mel.copy(), img_name)
|
119 |
+
|
120 |
+
if (mel.shape[0] != syncnet_mel_step_size):
|
121 |
+
continue
|
122 |
+
|
123 |
+
# H x W x 3 * T
|
124 |
+
x = np.concatenate(window, axis=2) / 255.
|
125 |
+
x = x.transpose(2, 0, 1)
|
126 |
+
x = x[:, x.shape[1]//2:]
|
127 |
+
|
128 |
+
x = torch.FloatTensor(x)
|
129 |
+
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
130 |
+
|
131 |
+
return x, mel, y
|
132 |
+
|
133 |
+
logloss = nn.BCELoss()
|
134 |
+
def cosine_loss(a, v, y):
|
135 |
+
d = nn.functional.cosine_similarity(a, v)
|
136 |
+
loss = logloss(d.unsqueeze(1), y)
|
137 |
+
|
138 |
+
return loss
|
139 |
+
|
140 |
+
def train(device, model, train_data_loader, test_data_loader, optimizer,
|
141 |
+
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
|
142 |
+
|
143 |
+
global global_step, global_epoch
|
144 |
+
resumed_step = global_step
|
145 |
+
|
146 |
+
while global_epoch < nepochs:
|
147 |
+
running_loss = 0.
|
148 |
+
prog_bar = tqdm(enumerate(train_data_loader))
|
149 |
+
for step, (x, mel, y) in prog_bar:
|
150 |
+
model.train()
|
151 |
+
optimizer.zero_grad()
|
152 |
+
|
153 |
+
# Transform data to CUDA device
|
154 |
+
x = x.to(device)
|
155 |
+
|
156 |
+
mel = mel.to(device)
|
157 |
+
|
158 |
+
a, v = model(mel, x)
|
159 |
+
y = y.to(device)
|
160 |
+
|
161 |
+
loss = cosine_loss(a, v, y)
|
162 |
+
loss.backward()
|
163 |
+
optimizer.step()
|
164 |
+
|
165 |
+
global_step += 1
|
166 |
+
cur_session_steps = global_step - resumed_step
|
167 |
+
running_loss += loss.item()
|
168 |
+
|
169 |
+
if global_step == 1 or global_step % checkpoint_interval == 0:
|
170 |
+
save_checkpoint(
|
171 |
+
model, optimizer, global_step, checkpoint_dir, global_epoch)
|
172 |
+
|
173 |
+
if global_step % hparams.syncnet_eval_interval == 0:
|
174 |
+
with torch.no_grad():
|
175 |
+
eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
|
176 |
+
|
177 |
+
prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1)))
|
178 |
+
|
179 |
+
global_epoch += 1
|
180 |
+
|
181 |
+
def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
|
182 |
+
eval_steps = 1400
|
183 |
+
print('Evaluating for {} steps'.format(eval_steps))
|
184 |
+
losses = []
|
185 |
+
while 1:
|
186 |
+
for step, (x, mel, y) in enumerate(test_data_loader):
|
187 |
+
|
188 |
+
model.eval()
|
189 |
+
|
190 |
+
# Transform data to CUDA device
|
191 |
+
x = x.to(device)
|
192 |
+
|
193 |
+
mel = mel.to(device)
|
194 |
+
|
195 |
+
a, v = model(mel, x)
|
196 |
+
y = y.to(device)
|
197 |
+
|
198 |
+
loss = cosine_loss(a, v, y)
|
199 |
+
losses.append(loss.item())
|
200 |
+
|
201 |
+
if step > eval_steps: break
|
202 |
+
|
203 |
+
averaged_loss = sum(losses) / len(losses)
|
204 |
+
print(averaged_loss)
|
205 |
+
|
206 |
+
return
|
207 |
+
|
208 |
+
def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
|
209 |
+
|
210 |
+
checkpoint_path = join(
|
211 |
+
checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
|
212 |
+
optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
|
213 |
+
torch.save({
|
214 |
+
"state_dict": model.state_dict(),
|
215 |
+
"optimizer": optimizer_state,
|
216 |
+
"global_step": step,
|
217 |
+
"global_epoch": epoch,
|
218 |
+
}, checkpoint_path)
|
219 |
+
print("Saved checkpoint:", checkpoint_path)
|
220 |
+
|
221 |
+
def _load(checkpoint_path):
|
222 |
+
if use_cuda:
|
223 |
+
checkpoint = torch.load(checkpoint_path)
|
224 |
+
else:
|
225 |
+
checkpoint = torch.load(checkpoint_path,
|
226 |
+
map_location=lambda storage, loc: storage)
|
227 |
+
return checkpoint
|
228 |
+
|
229 |
+
def load_checkpoint(path, model, optimizer, reset_optimizer=False):
|
230 |
+
global global_step
|
231 |
+
global global_epoch
|
232 |
+
|
233 |
+
print("Load checkpoint from: {}".format(path))
|
234 |
+
checkpoint = _load(path)
|
235 |
+
model.load_state_dict(checkpoint["state_dict"])
|
236 |
+
if not reset_optimizer:
|
237 |
+
optimizer_state = checkpoint["optimizer"]
|
238 |
+
if optimizer_state is not None:
|
239 |
+
print("Load optimizer state from {}".format(path))
|
240 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
241 |
+
global_step = checkpoint["global_step"]
|
242 |
+
global_epoch = checkpoint["global_epoch"]
|
243 |
+
|
244 |
+
return model
|
245 |
+
|
246 |
+
if __name__ == "__main__":
|
247 |
+
checkpoint_dir = args.checkpoint_dir
|
248 |
+
checkpoint_path = args.checkpoint_path
|
249 |
+
|
250 |
+
if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)
|
251 |
+
|
252 |
+
# Dataset and Dataloader setup
|
253 |
+
train_dataset = Dataset('train')
|
254 |
+
test_dataset = Dataset('val')
|
255 |
+
|
256 |
+
train_data_loader = data_utils.DataLoader(
|
257 |
+
train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
|
258 |
+
num_workers=hparams.num_workers)
|
259 |
+
|
260 |
+
test_data_loader = data_utils.DataLoader(
|
261 |
+
test_dataset, batch_size=hparams.syncnet_batch_size,
|
262 |
+
num_workers=8)
|
263 |
+
|
264 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
265 |
+
|
266 |
+
# Model
|
267 |
+
model = SyncNet().to(device)
|
268 |
+
print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
269 |
+
|
270 |
+
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
|
271 |
+
lr=hparams.syncnet_lr)
|
272 |
+
|
273 |
+
if checkpoint_path is not None:
|
274 |
+
load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)
|
275 |
+
|
276 |
+
train(device, model, train_data_loader, test_data_loader, optimizer,
|
277 |
+
checkpoint_dir=checkpoint_dir,
|
278 |
+
checkpoint_interval=hparams.syncnet_checkpoint_interval,
|
279 |
+
nepochs=hparams.nepochs)
|
evaluation/README.md
ADDED
@@ -0,0 +1,63 @@
1 |
+
# Novel Evaluation Framework, new filelists, and the LSE-D and LSE-C metrics
|
2 |
+
|
3 |
+
Our paper also proposes a novel evaluation framework (Section 4). To evaluate on LRS2, LRS3, and LRW, the filelists are present in the `test_filelists` folder. Please use the `gen_videos_from_filelist.py` script to generate the videos. After that, you can calculate the LSE-D and LSE-C scores using the instructions below. Please see [this thread](https://github.com/Rudrabha/Wav2Lip/issues/22#issuecomment-712825380) for how to calculate the FID scores.
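As a rough sketch, a generation run with `gen_videos_from_filelist.py` looks like the following (all paths and the filelist name are placeholders; the flags are the script's required arguments):

```
cd Wav2Lip/evaluation
python gen_videos_from_filelist.py \
    --filelist test_filelists/<your_test_filelist>.txt \
    --data_root /path/to/preprocessed/dataset \
    --results_dir /path/to/save/generated_videos \
    --checkpoint_path ../checkpoints/wav2lip.pth
```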
|
4 |
+
|
5 |
+
The videos of the ReSyncED benchmark for real-world evaluation will be released soon.
|
6 |
+
|
7 |
+
### Steps to set up the evaluation repository for the LSE-D and LSE-C metrics:
|
8 |
+
We use the pre-trained syncnet model available in this [repository](https://github.com/joonson/syncnet_python).
|
9 |
+
|
10 |
+
* Clone the SyncNet repository.
|
11 |
+
```
|
12 |
+
git clone https://github.com/joonson/syncnet_python.git
|
13 |
+
```
|
14 |
+
* Follow the procedure given in the above linked [repository](https://github.com/joonson/syncnet_python) to download the pretrained models and set up the dependencies.
|
15 |
+
* **Note: Please use a separate virtual environment for the evaluation scripts. The package versions used by Wav2Lip and the publicly released SyncNet code are different and can cause version-mismatch issues, so we suggest installing the evaluation dependencies in a virtual environment of their own.**
|
16 |
+
```
|
17 |
+
cd syncnet_python
|
18 |
+
pip install -r requirements.txt
|
19 |
+
sh download_model.sh
|
20 |
+
```
|
21 |
+
* The above steps should ensure that all the dependencies required by the repository are installed and the pre-trained models are downloaded.
|
22 |
+
|
23 |
+
### Running the evaluation scripts:
|
24 |
+
* Copy our evaluation scripts given in this folder to the cloned repository.
|
25 |
+
```
|
26 |
+
cd Wav2Lip/evaluation/scores_LSE/
|
27 |
+
cp *.py syncnet_python/
|
28 |
+
cp *.sh syncnet_python/
|
29 |
+
```
|
30 |
+
**Note: We will release the test filelists for LRW, LRS2 and LRS3 once we receive permission from the dataset creators. We will also release the real-world dataset we have collected shortly.**
|
31 |
+
|
32 |
+
* Our evaluation technique does not require ground truth of any sort. Given lip-synced videos, we can calculate the scores directly from the generated videos alone. Please store the generated videos (from our test sets or your own) in the following folder structure.
|
33 |
+
```
|
34 |
+
video data root (Folder containing all videos)
|
35 |
+
├── All .mp4 files
|
36 |
+
```
|
37 |
+
* Change directory back to the cloned repository.
|
38 |
+
```
|
39 |
+
cd syncnet_python
|
40 |
+
```
|
41 |
+
* To run evaluation on the LRW, LRS2 and LRS3 test files, please run the following command:
|
42 |
+
```
|
43 |
+
python calculate_scores_LRS.py --data_root /path/to/video/data/root --tmp_dir tmp_dir/
|
44 |
+
```
|
45 |
+
|
46 |
+
* To run evaluation on the ReSyncED dataset or your own generated videos, please run the following command:
|
47 |
+
```
|
48 |
+
sh calculate_scores_real_videos.sh /path/to/video/data/root
|
49 |
+
```
|
50 |
+
* The generated scores will be written to ```all_scores.txt``` in the ```syncnet_python/``` folder.
|
51 |
+
|
52 |
+
# Evaluation of image quality using the FID metric
|
53 |
+
We use the [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository to calculate the FID metric. We dump all the frames of both the ground-truth and the generated videos and compute the FID score between the two sets of frames.
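A minimal sketch of that procedure, assuming `ffmpeg` and the `pytorch-fid` package are installed (all paths are placeholders):

```
# dump frames from a ground-truth and a generated video into two folders
mkdir -p gt_frames gen_frames
ffmpeg -i gt_video.mp4 gt_frames/%05d.png
ffmpeg -i generated_video.mp4 gen_frames/%05d.png

# compute the FID score between the two sets of frames
python -m pytorch_fid gt_frames/ gen_frames/
```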
|
54 |
+
|
55 |
+
|
56 |
+
# Opening issues related to evaluation scripts
|
57 |
+
* Please open issues with the "Evaluation" label if you face any problems with the evaluation scripts.
|
58 |
+
|
59 |
+
# Acknowledgements
|
60 |
+
Our evaluation pipeline is based on two existing repositories. The LSE metrics are based on the [syncnet_python](https://github.com/joonson/syncnet_python) repository and the FID score is based on the [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository. We thank the authors of both repositories for releasing their wonderful code.
|
61 |
+
|
62 |
+
|
63 |
+
|
evaluation/gen_videos_from_filelist.py
ADDED
@@ -0,0 +1,238 @@
1 |
+
from os import listdir, path
|
2 |
+
import numpy as np
|
3 |
+
import scipy, cv2, os, sys, argparse
|
4 |
+
import dlib, json, subprocess
|
5 |
+
from tqdm import tqdm
|
6 |
+
from glob import glob
|
7 |
+
import torch
|
8 |
+
|
9 |
+
sys.path.append('../')
|
10 |
+
import audio
|
11 |
+
import face_detection
|
12 |
+
from models import Wav2Lip
|
13 |
+
|
14 |
+
parser = argparse.ArgumentParser(description='Code to generate results for test filelists')
|
15 |
+
|
16 |
+
parser.add_argument('--filelist', type=str,
|
17 |
+
help='Filepath of filelist file to read', required=True)
|
18 |
+
parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
|
19 |
+
required=True)
|
20 |
+
parser.add_argument('--data_root', type=str, required=True)
|
21 |
+
parser.add_argument('--checkpoint_path', type=str,
|
22 |
+
help='Name of saved checkpoint to load weights from', required=True)
|
23 |
+
|
24 |
+
parser.add_argument('--pads', nargs='+', type=int, default=[0, 0, 0, 0],
|
25 |
+
help='Padding (top, bottom, left, right)')
|
26 |
+
parser.add_argument('--face_det_batch_size', type=int,
|
27 |
+
help='Single GPU batch size for face detection', default=64)
|
28 |
+
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
|
29 |
+
|
30 |
+
# parser.add_argument('--resize_factor', default=1, type=int)
|
31 |
+
|
32 |
+
args = parser.parse_args()
|
33 |
+
args.img_size = 96
|
34 |
+
|
35 |
+
def get_smoothened_boxes(boxes, T):
|
36 |
+
for i in range(len(boxes)):
|
37 |
+
if i + T > len(boxes):
|
38 |
+
window = boxes[len(boxes) - T:]
|
39 |
+
else:
|
40 |
+
window = boxes[i : i + T]
|
41 |
+
boxes[i] = np.mean(window, axis=0)
|
42 |
+
return boxes
|
43 |
+
|
44 |
+
def face_detect(images):
|
45 |
+
batch_size = args.face_det_batch_size
|
46 |
+
|
47 |
+
while 1:
|
48 |
+
predictions = []
|
49 |
+
try:
|
50 |
+
for i in range(0, len(images), batch_size):
|
51 |
+
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
|
52 |
+
except RuntimeError:
|
53 |
+
if batch_size == 1:
|
54 |
+
raise RuntimeError('Image too big to run face detection on GPU')
|
55 |
+
batch_size //= 2
|
56 |
+
args.face_det_batch_size = batch_size
|
57 |
+
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
|
58 |
+
continue
|
59 |
+
break
|
60 |
+
|
61 |
+
results = []
|
62 |
+
pady1, pady2, padx1, padx2 = args.pads
|
63 |
+
for rect, image in zip(predictions, images):
|
64 |
+
if rect is None:
|
65 |
+
raise ValueError('Face not detected!')
|
66 |
+
|
67 |
+
y1 = max(0, rect[1] - pady1)
|
68 |
+
y2 = min(image.shape[0], rect[3] + pady2)
|
69 |
+
x1 = max(0, rect[0] - padx1)
|
70 |
+
x2 = min(image.shape[1], rect[2] + padx2)
|
71 |
+
|
72 |
+
results.append([x1, y1, x2, y2])
|
73 |
+
|
74 |
+
boxes = get_smoothened_boxes(np.array(results), T=5)
|
75 |
+
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
|
76 |
+
|
77 |
+
return results
|
78 |
+
|
79 |
+
def datagen(frames, face_det_results, mels):
|
80 |
+
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
81 |
+
|
82 |
+
for i, m in enumerate(mels):
|
83 |
+
if i >= len(frames): raise ValueError('Equal or less lengths only')
|
84 |
+
|
85 |
+
frame_to_save = frames[i].copy()
|
86 |
+
face, coords, valid_frame = face_det_results[i].copy()
|
87 |
+
if not valid_frame:
|
88 |
+
continue
|
89 |
+
|
90 |
+
face = cv2.resize(face, (args.img_size, args.img_size))
|
91 |
+
|
92 |
+
img_batch.append(face)
|
93 |
+
mel_batch.append(m)
|
94 |
+
frame_batch.append(frame_to_save)
|
95 |
+
coords_batch.append(coords)
|
96 |
+
|
97 |
+
if len(img_batch) >= args.wav2lip_batch_size:
|
98 |
+
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
99 |
+
|
100 |
+
img_masked = img_batch.copy()
|
101 |
+
img_masked[:, args.img_size//2:] = 0
|
102 |
+
|
103 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
104 |
+
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
105 |
+
|
106 |
+
yield img_batch, mel_batch, frame_batch, coords_batch
|
107 |
+
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
108 |
+
|
109 |
+
if len(img_batch) > 0:
|
110 |
+
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
111 |
+
|
112 |
+
img_masked = img_batch.copy()
|
113 |
+
img_masked[:, args.img_size//2:] = 0
|
114 |
+
|
115 |
+
        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch

fps = 25
mel_step_size = 16
mel_idx_multiplier = 80./fps
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))

detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                        flip_input=False, device=device)

def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    return model.eval()

model = load_model(args.checkpoint_path)

def main():
    assert args.data_root is not None
    data_root = args.data_root

    if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)

    with open(args.filelist, 'r') as filelist:
        lines = filelist.readlines()

    for idx, line in enumerate(tqdm(lines)):
        audio_src, video = line.strip().split()

        audio_src = os.path.join(data_root, audio_src) + '.mp4'
        video = os.path.join(data_root, video) + '.mp4'

        command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
        subprocess.call(command, shell=True)
        temp_audio = '../temp/temp.wav'

        wav = audio.load_wav(temp_audio, 16000)
        mel = audio.melspectrogram(wav)
        if np.isnan(mel.reshape(-1)).sum() > 0:
            continue

        mel_chunks = []
        i = 0
        while 1:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + mel_step_size > len(mel[0]):
                break
            mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
            i += 1

        video_stream = cv2.VideoCapture(video)

        full_frames = []
        while 1:
            still_reading, frame = video_stream.read()
            if not still_reading or len(full_frames) > len(mel_chunks):
                video_stream.release()
                break
            full_frames.append(frame)

        if len(full_frames) < len(mel_chunks):
            continue

        full_frames = full_frames[:len(mel_chunks)]

        try:
            face_det_results = face_detect(full_frames.copy())
        except ValueError as e:
            continue

        batch_size = args.wav2lip_batch_size
        gen = datagen(full_frames.copy(), face_det_results, mel_chunks)

        for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
            if i == 0:
                frame_h, frame_w = full_frames[0].shape[:-1]
                out = cv2.VideoWriter('../temp/result.avi',
                                      cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for pl, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = pl
                out.write(f)

        out.release()

        vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))

        command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio,
                                                                                      '../temp/result.avi', vid)
        subprocess.call(command, shell=True)

if __name__ == '__main__':
    main()
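The chunking loop above ties video frames to audio: the mel spectrogram has 80 frames per second, so output frame i reads a 16-frame window starting at int(i * 80 / fps). A minimal standalone sketch of that mapping, using a dummy spectrogram whose shape and length are made up only for illustration:

import numpy as np

fps = 25                                  # fixed in this script
mel_step_size = 16
mel_idx_multiplier = 80. / fps            # mel frames advanced per video frame
mel = np.zeros((80, 240))                 # dummy (n_mels, T) spectrogram, roughly 3 s of audio

mel_chunks = []
i = 0
while True:
    start_idx = int(i * mel_idx_multiplier)
    if start_idx + mel_step_size > mel.shape[1]:
        break                             # stop once a full 16-frame window no longer fits
    mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
    i += 1

print(len(mel_chunks))                    # one chunk per video frame that can be generated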
evaluation/real_videos_inference.py
ADDED
@@ -0,0 +1,305 @@
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse
import dlib, json, subprocess
from tqdm import tqdm
from glob import glob
import torch

sys.path.append('../')
import audio
import face_detection
from models import Wav2Lip

parser = argparse.ArgumentParser(description='Code to generate results on ReSyncED evaluation set')

parser.add_argument('--mode', type=str,
                    help='random | dubbed | tts', required=True)

parser.add_argument('--filelist', type=str,
                    help='Filepath of filelist file to read', default=None)

parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
                    required=True)
parser.add_argument('--data_root', type=str, required=True)
parser.add_argument('--checkpoint_path', type=str,
                    help='Name of saved checkpoint to load weights from', required=True)
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                    help='Padding (top, bottom, left, right)')

parser.add_argument('--face_det_batch_size', type=int,
                    help='Single GPU batch size for face detection', default=16)

parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
parser.add_argument('--face_res', help='Approximate resolution of the face at which to test', default=180)
parser.add_argument('--min_frame_res', help='Do not downsample further below this frame resolution', default=480)
parser.add_argument('--max_frame_res', help='Downsample to at least this frame resolution', default=720)
# parser.add_argument('--resize_factor', default=1, type=int)

args = parser.parse_args()
args.img_size = 96

def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def rescale_frames(images):
    rect = detector.get_detections_for_batch(np.array([images[0]]))[0]
    if rect is None:
        raise ValueError('Face not detected!')
    h, w = images[0].shape[:-1]

    x1, y1, x2, y2 = rect

    face_size = max(np.abs(y1 - y2), np.abs(x1 - x2))

    diff = np.abs(face_size - args.face_res)
    for factor in range(2, 16):
        downsampled_res = face_size // factor
        if min(h//factor, w//factor) < args.min_frame_res: break
        if np.abs(downsampled_res - args.face_res) >= diff: break

    factor -= 1
    if factor == 1: return images

    return [cv2.resize(im, (im.shape[1]//(factor), im.shape[0]//(factor))) for im in images]


def face_detect(images):
    batch_size = args.face_det_batch_size
    images = rescale_frames(images)

    while 1:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            raise ValueError('Face not detected!')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = get_smoothened_boxes(np.array(results), T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    return results, images

def datagen(frames, face_det_results, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    for i, m in enumerate(mels):
        if i >= len(frames): raise ValueError('Equal or less lengths only')

        frame_to_save = frames[i].copy()
        face, coords, valid_frame = face_det_results[i].copy()
        if not valid_frame:
            continue

        face = cv2.resize(face, (args.img_size, args.img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, args.img_size//2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size//2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch

def increase_frames(frames, l):
    ## evenly duplicating frames to increase length of video
    while len(frames) < l:
        dup_every = float(l) / len(frames)

        final_frames = []
        next_duplicate = 0.

        for i, f in enumerate(frames):
            final_frames.append(f)

            if int(np.ceil(next_duplicate)) == i:
                final_frames.append(f)

            next_duplicate += dup_every

        frames = final_frames

    return frames[:l]

mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))

detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                        flip_input=False, device=device)

def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    return model.eval()

model = load_model(args.checkpoint_path)

def main():
    if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)

    if args.mode == 'dubbed':
        files = listdir(args.data_root)
        lines = ['{} {}'.format(f, f) for f in files]

    else:
        assert args.filelist is not None
        with open(args.filelist, 'r') as filelist:
            lines = filelist.readlines()

    for idx, line in enumerate(tqdm(lines)):
        video, audio_src = line.strip().split()

        audio_src = os.path.join(args.data_root, audio_src)
        video = os.path.join(args.data_root, video)

        command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
        subprocess.call(command, shell=True)
        temp_audio = '../temp/temp.wav'

        wav = audio.load_wav(temp_audio, 16000)
        mel = audio.melspectrogram(wav)

        if np.isnan(mel.reshape(-1)).sum() > 0:
            raise ValueError('Mel contains nan!')

        video_stream = cv2.VideoCapture(video)

        fps = video_stream.get(cv2.CAP_PROP_FPS)
        mel_idx_multiplier = 80./fps

        full_frames = []
        while 1:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break

            if min(frame.shape[:-1]) > args.max_frame_res:
                h, w = frame.shape[:-1]
                scale_factor = min(h, w) / float(args.max_frame_res)
                h = int(h/scale_factor)
                w = int(w/scale_factor)

                frame = cv2.resize(frame, (w, h))
            full_frames.append(frame)

        mel_chunks = []
        i = 0
        while 1:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + mel_step_size > len(mel[0]):
                break
            mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
            i += 1

        if len(full_frames) < len(mel_chunks):
            if args.mode == 'tts':
                full_frames = increase_frames(full_frames, len(mel_chunks))
            else:
                raise ValueError('#Frames, audio length mismatch')

        else:
            full_frames = full_frames[:len(mel_chunks)]

        try:
            face_det_results, full_frames = face_detect(full_frames.copy())
        except ValueError as e:
            continue

        batch_size = args.wav2lip_batch_size
        gen = datagen(full_frames.copy(), face_det_results, mel_chunks)

        for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
            if i == 0:
                frame_h, frame_w = full_frames[0].shape[:-1]

                out = cv2.VideoWriter('../temp/result.avi',
                                      cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for pl, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = pl
                out.write(f)

        out.release()

        vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
        command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format('../temp/temp.wav',
                                                                                      '../temp/result.avi', vid)
        subprocess.call(command, shell=True)


if __name__ == '__main__':
    main()
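get_smoothened_boxes() above stabilises the per-frame face crops by replacing each detection with the mean of a short window of consecutive boxes (the smoothing is done in place, so later windows see already-averaged values). A small standalone sketch of the same idea on made-up, jittery detections:

import numpy as np

def get_smoothened_boxes(boxes, T):
    # Average each box with the following T-1 boxes, clamping the window at the end.
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i: i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

# Dummy (x1, y1, x2, y2) detections for four frames.
boxes = np.array([[100., 100., 200., 200.],
                  [104.,  98., 204., 198.],
                  [ 97., 103., 197., 203.],
                  [101., 100., 201., 200.]])
print(get_smoothened_boxes(boxes, T=3))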
evaluation/scores_LSE/SyncNetInstance_calc_scores.py
ADDED
@@ -0,0 +1,210 @@
#!/usr/bin/python
#-*- coding: utf-8 -*-
# Video 25 FPS, Audio 16000HZ

import torch
import numpy
import time, pdb, argparse, subprocess, os, math, glob
import cv2
import python_speech_features

from scipy import signal
from scipy.io import wavfile
from SyncNetModel import *
from shutil import rmtree


# ==================== Get OFFSET ====================

def calc_pdist(feat1, feat2, vshift=10):

    win_size = vshift*2+1

    feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift))

    dists = []

    for i in range(0,len(feat1)):

        dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:]))

    return dists

# ==================== MAIN DEF ====================

class SyncNetInstance(torch.nn.Module):

    def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024):
        super(SyncNetInstance, self).__init__();

        self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda();

    def evaluate(self, opt, videofile):

        self.__S__.eval();

        # ========== ==========
        # Convert files
        # ========== ==========

        if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)):
            rmtree(os.path.join(opt.tmp_dir,opt.reference))

        os.makedirs(os.path.join(opt.tmp_dir,opt.reference))

        command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg')))
        output = subprocess.call(command, shell=True, stdout=None)

        command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav')))
        output = subprocess.call(command, shell=True, stdout=None)

        # ========== ==========
        # Load video
        # ========== ==========

        images = []

        flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg'))
        flist.sort()

        for fname in flist:
            img_input = cv2.imread(fname)
            img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE
            images.append(img_input)

        im = numpy.stack(images,axis=3)
        im = numpy.expand_dims(im,axis=0)
        im = numpy.transpose(im,(0,3,4,1,2))

        imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())

        # ========== ==========
        # Load audio
        # ========== ==========

        sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))
        mfcc = zip(*python_speech_features.mfcc(audio,sample_rate))
        mfcc = numpy.stack([numpy.array(i) for i in mfcc])

        cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0)
        cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())

        # ========== ==========
        # Check audio and video input length
        # ========== ==========

        #if (float(len(audio))/16000) != (float(len(images))/25) :
        #    print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25))

        min_length = min(len(images),math.floor(len(audio)/640))

        # ========== ==========
        # Generate video and audio feats
        # ========== ==========

        lastframe = min_length-5
        im_feat = []
        cc_feat = []

        tS = time.time()
        for i in range(0,lastframe,opt.batch_size):

            im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
            im_in = torch.cat(im_batch,0)
            im_out = self.__S__.forward_lip(im_in.cuda());
            im_feat.append(im_out.data.cpu())

            cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
            cc_in = torch.cat(cc_batch,0)
            cc_out = self.__S__.forward_aud(cc_in.cuda())
            cc_feat.append(cc_out.data.cpu())

        im_feat = torch.cat(im_feat,0)
        cc_feat = torch.cat(cc_feat,0)

        # ========== ==========
        # Compute offset
        # ========== ==========

        #print('Compute time %.3f sec.' % (time.time()-tS))

        dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift)
        mdist = torch.mean(torch.stack(dists,1),1)

        minval, minidx = torch.min(mdist,0)

        offset = opt.vshift-minidx
        conf = torch.median(mdist) - minval

        fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
        # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
        fconf = torch.median(mdist).numpy() - fdist
        fconfm = signal.medfilt(fconf,kernel_size=9)

        numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format})
        #print('Framewise conf: ')
        #print(fconfm)
        #print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf))

        dists_npy = numpy.array([ dist.numpy() for dist in dists ])
        return offset.numpy(), conf.numpy(), minval.numpy()

    def extract_feature(self, opt, videofile):

        self.__S__.eval();

        # ========== ==========
        # Load video
        # ========== ==========
        cap = cv2.VideoCapture(videofile)

        frame_num = 1;
        images = []
        while frame_num:
            frame_num += 1
            ret, image = cap.read()
            if ret == 0:
                break

            images.append(image)

        im = numpy.stack(images,axis=3)
        im = numpy.expand_dims(im,axis=0)
        im = numpy.transpose(im,(0,3,4,1,2))

        imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())

        # ========== ==========
        # Generate video feats
        # ========== ==========

        lastframe = len(images)-4
        im_feat = []

        tS = time.time()
        for i in range(0,lastframe,opt.batch_size):

            im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
            im_in = torch.cat(im_batch,0)
            im_out = self.__S__.forward_lipfeat(im_in.cuda());
            im_feat.append(im_out.data.cpu())

        im_feat = torch.cat(im_feat,0)

        # ========== ==========
        # Compute offset
        # ========== ==========

        print('Compute time %.3f sec.' % (time.time()-tS))

        return im_feat


    def loadParameters(self, path):
        loaded_state = torch.load(path, map_location=lambda storage, loc: storage);

        self_state = self.__S__.state_dict();

        for name, param in loaded_state.items():

            self_state[name].copy_(param);
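The offset and confidence returned by evaluate() come from a simple reduction: for every video frame, measure the distance between its visual embedding and the audio embeddings shifted by -vshift .. +vshift, average those distance curves over time, then take the argmin (the audio-visual offset and its minimum distance) and the gap between the median and the minimum (the confidence). A toy sketch of that reduction on random embeddings; the feature shapes below are assumptions chosen only so the example runs:

import torch

def calc_pdist(feat1, feat2, vshift=10):
    # Distances between each visual feature and a +/- vshift window of audio features.
    win_size = vshift * 2 + 1
    feat2p = torch.nn.functional.pad(feat2, (0, 0, vshift, vshift))
    return [torch.nn.functional.pairwise_distance(
                feat1[[i], :].repeat(win_size, 1), feat2p[i:i + win_size, :])
            for i in range(len(feat1))]

vshift = 15
im_feat = torch.randn(100, 1024)               # dummy per-frame visual embeddings
cc_feat = torch.randn(100, 1024)               # dummy per-frame audio embeddings

dists = calc_pdist(im_feat, cc_feat, vshift=vshift)
mdist = torch.mean(torch.stack(dists, 1), 1)   # mean distance per candidate shift
minval, minidx = torch.min(mdist, 0)

offset = vshift - minidx                       # estimated AV offset in frames
conf = torch.median(mdist) - minval            # confidence margin
print(int(offset), float(minval), float(conf))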
evaluation/scores_LSE/calculate_scores_LRS.py
ADDED
@@ -0,0 +1,53 @@
#!/usr/bin/python
#-*- coding: utf-8 -*-

import time, pdb, argparse, subprocess
import glob
import os
from tqdm import tqdm

from SyncNetInstance_calc_scores import *

# ==================== LOAD PARAMS ====================


parser = argparse.ArgumentParser(description = "SyncNet");

parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='');
parser.add_argument('--batch_size', type=int, default='20', help='');
parser.add_argument('--vshift', type=int, default='15', help='');
parser.add_argument('--data_root', type=str, required=True, help='');
parser.add_argument('--tmp_dir', type=str, default="data/work/pytmp", help='');
parser.add_argument('--reference', type=str, default="demo", help='');

opt = parser.parse_args();


# ==================== RUN EVALUATION ====================

s = SyncNetInstance();

s.loadParameters(opt.initial_model);
#print("Model %s loaded."%opt.initial_model);
path = os.path.join(opt.data_root, "*.mp4")

all_videos = glob.glob(path)

prog_bar = tqdm(range(len(all_videos)))
avg_confidence = 0.
avg_min_distance = 0.


for videofile_idx in prog_bar:
    videofile = all_videos[videofile_idx]
    offset, confidence, min_distance = s.evaluate(opt, videofile=videofile)
    avg_confidence += confidence
    avg_min_distance += min_distance
    prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3)))
    prog_bar.refresh()

print ('Average Confidence: {}'.format(avg_confidence/len(all_videos)))
print ('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos)))
evaluation/scores_LSE/calculate_scores_real_videos.py
ADDED
@@ -0,0 +1,45 @@
#!/usr/bin/python
#-*- coding: utf-8 -*-

import time, pdb, argparse, subprocess, pickle, os, gzip, glob

from SyncNetInstance_calc_scores import *

# ==================== PARSE ARGUMENT ====================

parser = argparse.ArgumentParser(description = "SyncNet");
parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='');
parser.add_argument('--batch_size', type=int, default='20', help='');
parser.add_argument('--vshift', type=int, default='15', help='');
parser.add_argument('--data_dir', type=str, default='data/work', help='');
parser.add_argument('--videofile', type=str, default='', help='');
parser.add_argument('--reference', type=str, default='', help='');
opt = parser.parse_args();

setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi'))
setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp'))
setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork'))
setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop'))


# ==================== LOAD MODEL AND FILE LIST ====================

s = SyncNetInstance();

s.loadParameters(opt.initial_model);
#print("Model %s loaded."%opt.initial_model);

flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi'))
flist.sort()

# ==================== GET OFFSETS ====================

dists = []
for idx, fname in enumerate(flist):
    offset, conf, dist = s.evaluate(opt,videofile=fname)
    print (str(dist)+" "+str(conf))

# ==================== PRINT RESULTS TO FILE ====================

#with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil:
#    pickle.dump(dists, fil)
evaluation/scores_LSE/calculate_scores_real_videos.sh
ADDED
@@ -0,0 +1,8 @@
rm all_scores.txt
yourfilenames=`ls $1`

for eachfile in $yourfilenames
do
    python run_pipeline.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir
    python calculate_scores_real_videos.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir >> all_scores.txt
done
evaluation/test_filelists/README.md
ADDED
@@ -0,0 +1,13 @@
This folder contains the filelists for the new evaluation framework proposed in the paper.

## Test filelists for LRS2, LRS3, and LRW.

This folder contains three filelists, each containing a list of names of audio-video pairs from the test sets of LRS2, LRS3, and LRW. The LRS2 and LRW filelists are strictly "Copyright BBC" and can only be used for “non-commercial research by applicants who have an agreement with the BBC to access the Lip Reading in the Wild and/or Lip Reading Sentences in the Wild datasets”. Please follow this link for more details: [https://www.bbc.co.uk/rd/projects/lip-reading-datasets](https://www.bbc.co.uk/rd/projects/lip-reading-datasets).


## ReSyncED benchmark

The sub-folder `ReSyncED` contains filelists for our own Real-world lip-Sync Evaluation Dataset (ReSyncED).


#### Instructions on how to use the above two filelists are available in the README of the parent folder.
evaluation/test_filelists/ReSyncED/random_pairs.txt
ADDED
@@ -0,0 +1,160 @@
sachin.mp4 emma_cropped.mp4
sachin.mp4 mourinho.mp4
sachin.mp4 elon.mp4
sachin.mp4 messi2.mp4
sachin.mp4 cr1.mp4
sachin.mp4 sachin.mp4
sachin.mp4 sg.mp4
sachin.mp4 fergi.mp4
sachin.mp4 spanish_lec1.mp4
sachin.mp4 bush_small.mp4
sachin.mp4 macca_cut.mp4
sachin.mp4 ca_cropped.mp4
sachin.mp4 lecun.mp4
sachin.mp4 spanish_lec0.mp4
srk.mp4 emma_cropped.mp4
srk.mp4 mourinho.mp4
srk.mp4 elon.mp4
srk.mp4 messi2.mp4
srk.mp4 cr1.mp4
srk.mp4 srk.mp4
srk.mp4 sachin.mp4
srk.mp4 sg.mp4
srk.mp4 fergi.mp4
srk.mp4 spanish_lec1.mp4
srk.mp4 bush_small.mp4
srk.mp4 macca_cut.mp4
srk.mp4 ca_cropped.mp4
srk.mp4 guardiola.mp4
srk.mp4 lecun.mp4
srk.mp4 spanish_lec0.mp4
cr1.mp4 emma_cropped.mp4
cr1.mp4 elon.mp4
cr1.mp4 messi2.mp4
cr1.mp4 cr1.mp4
cr1.mp4 spanish_lec1.mp4
cr1.mp4 bush_small.mp4
cr1.mp4 macca_cut.mp4
cr1.mp4 ca_cropped.mp4
cr1.mp4 lecun.mp4
cr1.mp4 spanish_lec0.mp4
macca_cut.mp4 emma_cropped.mp4
macca_cut.mp4 elon.mp4
macca_cut.mp4 messi2.mp4
macca_cut.mp4 spanish_lec1.mp4
macca_cut.mp4 macca_cut.mp4
macca_cut.mp4 ca_cropped.mp4
macca_cut.mp4 spanish_lec0.mp4
lecun.mp4 emma_cropped.mp4
lecun.mp4 elon.mp4
lecun.mp4 messi2.mp4
lecun.mp4 spanish_lec1.mp4
lecun.mp4 macca_cut.mp4
lecun.mp4 ca_cropped.mp4
lecun.mp4 lecun.mp4
lecun.mp4 spanish_lec0.mp4
messi2.mp4 emma_cropped.mp4
messi2.mp4 elon.mp4
messi2.mp4 messi2.mp4
messi2.mp4 spanish_lec1.mp4
messi2.mp4 macca_cut.mp4
messi2.mp4 ca_cropped.mp4
messi2.mp4 spanish_lec0.mp4
ca_cropped.mp4 emma_cropped.mp4
ca_cropped.mp4 elon.mp4
ca_cropped.mp4 spanish_lec1.mp4
ca_cropped.mp4 ca_cropped.mp4
ca_cropped.mp4 spanish_lec0.mp4
spanish_lec1.mp4 spanish_lec1.mp4
spanish_lec1.mp4 spanish_lec0.mp4
elon.mp4 elon.mp4
elon.mp4 spanish_lec1.mp4
elon.mp4 spanish_lec0.mp4
guardiola.mp4 emma_cropped.mp4
guardiola.mp4 mourinho.mp4
guardiola.mp4 elon.mp4
guardiola.mp4 messi2.mp4
guardiola.mp4 cr1.mp4
guardiola.mp4 sachin.mp4
guardiola.mp4 sg.mp4
guardiola.mp4 fergi.mp4
guardiola.mp4 spanish_lec1.mp4
guardiola.mp4 bush_small.mp4
guardiola.mp4 macca_cut.mp4
guardiola.mp4 ca_cropped.mp4
guardiola.mp4 guardiola.mp4
guardiola.mp4 lecun.mp4
guardiola.mp4 spanish_lec0.mp4
fergi.mp4 emma_cropped.mp4
fergi.mp4 mourinho.mp4
fergi.mp4 elon.mp4
fergi.mp4 messi2.mp4
fergi.mp4 cr1.mp4
fergi.mp4 sachin.mp4
fergi.mp4 sg.mp4
fergi.mp4 fergi.mp4
fergi.mp4 spanish_lec1.mp4
fergi.mp4 bush_small.mp4
fergi.mp4 macca_cut.mp4
fergi.mp4 ca_cropped.mp4
fergi.mp4 lecun.mp4
fergi.mp4 spanish_lec0.mp4
spanish.mp4 emma_cropped.mp4
spanish.mp4 spanish.mp4
spanish.mp4 mourinho.mp4
spanish.mp4 elon.mp4
spanish.mp4 messi2.mp4
spanish.mp4 cr1.mp4
spanish.mp4 srk.mp4
spanish.mp4 sachin.mp4
spanish.mp4 sg.mp4
spanish.mp4 fergi.mp4
spanish.mp4 spanish_lec1.mp4
spanish.mp4 bush_small.mp4
spanish.mp4 macca_cut.mp4
spanish.mp4 ca_cropped.mp4
spanish.mp4 guardiola.mp4
spanish.mp4 lecun.mp4
spanish.mp4 spanish_lec0.mp4
bush_small.mp4 emma_cropped.mp4
bush_small.mp4 elon.mp4
bush_small.mp4 messi2.mp4
bush_small.mp4 spanish_lec1.mp4
bush_small.mp4 bush_small.mp4
bush_small.mp4 macca_cut.mp4
bush_small.mp4 ca_cropped.mp4
bush_small.mp4 lecun.mp4
bush_small.mp4 spanish_lec0.mp4
emma_cropped.mp4 emma_cropped.mp4
emma_cropped.mp4 elon.mp4
emma_cropped.mp4 spanish_lec1.mp4
emma_cropped.mp4 spanish_lec0.mp4
sg.mp4 emma_cropped.mp4
sg.mp4 mourinho.mp4
sg.mp4 elon.mp4
sg.mp4 messi2.mp4
sg.mp4 cr1.mp4
sg.mp4 sachin.mp4
sg.mp4 sg.mp4
sg.mp4 fergi.mp4
sg.mp4 spanish_lec1.mp4
sg.mp4 bush_small.mp4
sg.mp4 macca_cut.mp4
sg.mp4 ca_cropped.mp4
sg.mp4 lecun.mp4
sg.mp4 spanish_lec0.mp4
spanish_lec0.mp4 spanish_lec0.mp4
mourinho.mp4 emma_cropped.mp4
mourinho.mp4 mourinho.mp4
mourinho.mp4 elon.mp4
mourinho.mp4 messi2.mp4
mourinho.mp4 cr1.mp4
mourinho.mp4 sachin.mp4
mourinho.mp4 sg.mp4
mourinho.mp4 fergi.mp4
mourinho.mp4 spanish_lec1.mp4
mourinho.mp4 bush_small.mp4
mourinho.mp4 macca_cut.mp4
mourinho.mp4 ca_cropped.mp4
mourinho.mp4 lecun.mp4
mourinho.mp4 spanish_lec0.mp4
evaluation/test_filelists/ReSyncED/tts_pairs.txt
ADDED
@@ -0,0 +1,18 @@
adam_1.mp4 andreng_optimization.wav
agad_2.mp4 agad_2.wav
agad_1.mp4 agad_1.wav
agad_3.mp4 agad_3.wav
rms_prop_1.mp4 rms_prop_tts.wav
tf_1.mp4 tf_1.wav
tf_2.mp4 tf_2.wav
andrew_ng_ai_business.mp4 andrewng_business_tts.wav
covid_autopsy_1.mp4 autopsy_tts.wav
news_1.mp4 news_tts.wav
andrew_ng_fund_1.mp4 andrewng_ai_fund.wav
covid_treatments_1.mp4 covid_tts.wav
pytorch_v_tf.mp4 pytorch_vs_tf_eng.wav
pytorch_1.mp4 pytorch.wav
pkb_1.mp4 pkb_1.wav
ss_1.mp4 ss_1.wav
carlsen_1.mp4 carlsen_eng.wav
french.mp4 french.wav
face_detect.py
ADDED
@@ -0,0 +1,55 @@
import cv2
import mediapipe as mp

mp_face_mesh = mp.solutions.face_mesh
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_face_detection = mp.solutions.face_detection


def face_rect(images):
    with mp_face_detection.FaceDetection(
        model_selection=1, min_detection_confidence=0.5
    ) as face_detection:
        for image_cv2 in images:
            # Convert the BGR image to RGB and process it with MediaPipe Face Detection.
            results = face_detection.process(cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB))

            # Yield the bounding rectangle of each detected face.
            if not results.detections:
                yield None
                continue  # results.detections is None here; skip to the next frame
            for detection in results.detections:
                yield _get_bounding_rect(image_cv2, detection)


def _get_bounding_rect(
    image: mp_drawing.np.ndarray,
    detection: mp_drawing.detection_pb2.Detection,
):
    """
    Stolen from mediapipe.solutions.drawing_utils.draw_detection()
    """
    if not detection.location_data:
        return
    if image.shape[2] != mp_drawing._BGR_CHANNELS:
        raise ValueError("Input image must contain three channel bgr data.")
    image_rows, image_cols, _ = image.shape

    location = detection.location_data

    # get bounding box if exists.
    if not location.HasField("relative_bounding_box"):
        return
    relative_bounding_box = location.relative_bounding_box
    rect_start_point = mp_drawing._normalized_to_pixel_coordinates(
        relative_bounding_box.xmin, relative_bounding_box.ymin, image_cols, image_rows
    )
    rect_end_point = mp_drawing._normalized_to_pixel_coordinates(
        relative_bounding_box.xmin + relative_bounding_box.width,
        relative_bounding_box.ymin + relative_bounding_box.height,
        image_cols,
        image_rows,
    )

    return *rect_start_point, *rect_end_point
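face_rect() above is a generator: it takes an iterable of BGR frames and yields one (x1, y1, x2, y2) pixel rectangle per detected face, or None for a frame with no detection. A minimal usage sketch; the video path is a placeholder, not a file shipped with this Space:

import cv2
from face_detect import face_rect

cap = cv2.VideoCapture("input.mp4")       # hypothetical input video
frames = []
while True:
    ok, frame = cap.read()
    if not ok:
        break
    frames.append(frame)                  # keep BGR; face_rect converts to RGB itself
cap.release()

for rect in face_rect(frames):
    print(rect)                           # (x1, y1, x2, y2) or None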
face_detection/README.md
ADDED
@@ -0,0 +1 @@
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
face_detection/__init__.py
ADDED
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

__author__ = """Adrian Bulat"""
__email__ = '[email protected]'
__version__ = '1.0.1'

from .api import FaceAlignment, LandmarksType, NetworkSize
face_detection/api.py
ADDED
@@ -0,0 +1,79 @@
from __future__ import print_function
import os
import torch
from torch.utils.model_zoo import load_url
from enum import Enum
import numpy as np
import cv2
try:
    import urllib.request as request_file
except BaseException:
    import urllib as request_file

from .models import FAN, ResNetDepth
from .utils import *


class LandmarksType(Enum):
    """Enum class defining the type of landmarks to detect.

    ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
    ``_2halfD`` - this points represent the projection of the 3D points into 3D
    ``_3D`` - detect the points ``(x,y,z)``` in a 3D space

    """
    _2D = 1
    _2halfD = 2
    _3D = 3


class NetworkSize(Enum):
    # TINY = 1
    # SMALL = 2
    # MEDIUM = 3
    LARGE = 4

    def __new__(cls, value):
        member = object.__new__(cls)
        member._value_ = value
        return member

    def __int__(self):
        return self.value

ROOT = os.path.dirname(os.path.abspath(__file__))

class FaceAlignment:
    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose

        network_size = int(network_size)

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True

        # Get the face detector
        face_detector_module = __import__('face_detection.detection.' + face_detector,
                                          globals(), locals(), [face_detector], 0)
        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        images = images[..., ::-1]
        detected_faces = self.face_detector.detect_from_batch(images.copy())
        results = []

        for i, d in enumerate(detected_faces):
            if len(d) == 0:
                results.append(None)
                continue
            d = d[0]
            d = np.clip(d, 0, None)

            x1, y1, x2, y2 = map(int, d[:-1])
            results.append((x1, y1, x2, y2))

        return results
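A short usage sketch of this API, matching how the inference scripts in this repo call it: build a FaceAlignment instance (which loads the SFD detector) and hand get_detections_for_batch() a single numpy array of same-sized BGR frames; it returns one (x1, y1, x2, y2) tuple or None per frame. The frame filenames below are placeholders:

import cv2
import numpy as np
import torch

import face_detection

device = 'cuda' if torch.cuda.is_available() else 'cpu'
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                        flip_input=False, device=device)

# Hypothetical frames; all must share the same height and width to be batched.
frames = [cv2.imread('frame_%03d.jpg' % i) for i in range(4)]
preds = detector.get_detections_for_batch(np.array(frames))

for rect in preds:
    print(rect)   # (x1, y1, x2, y2) ints, or None if no face was found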
face_detection/detection/__init__.py
ADDED
@@ -0,0 +1 @@
from .core import FaceDetector
face_detection/detection/core.py
ADDED
@@ -0,0 +1,130 @@
import logging
import glob
from tqdm import tqdm
import numpy as np
import torch
import cv2


class FaceDetector(object):
    """An abstract class representing a face detector.

    Any other face detection implementation must subclass it. All subclasses
    must implement ``detect_from_image``, that return a list of detected
    bounding boxes. Optionally, for speed considerations detect from path is
    recommended.
    """

    def __init__(self, device, verbose):
        self.device = device
        self.verbose = verbose

        if verbose:
            if 'cpu' in device:
                logger = logging.getLogger(__name__)
                logger.warning("Detection running on CPU, this may be potentially slow.")

        if 'cpu' not in device and 'cuda' not in device:
            if verbose:
                logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
            raise ValueError

    def detect_from_image(self, tensor_or_path):
        """Detects faces in a given image.

        This function detects the faces present in a provided BGR(usually)
        image. The input can be either the image itself or the path to it.

        Arguments:
            tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
            to an image or the image itself.

        Example::

            >>> path_to_image = 'data/image_01.jpg'
            ... detected_faces = detect_from_image(path_to_image)
            [A list of bounding boxes (x1, y1, x2, y2)]
            >>> image = cv2.imread(path_to_image)
            ... detected_faces = detect_from_image(image)
            [A list of bounding boxes (x1, y1, x2, y2)]

        """
        raise NotImplementedError

    def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
        """Detects faces from all the images present in a given directory.

        Arguments:
            path {string} -- a string containing a path that points to the folder containing the images

        Keyword Arguments:
            extensions {list} -- list of string containing the extensions to be
            consider in the following format: ``.extension_name`` (default:
            {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
            folder recursively (default: {False}) show_progress_bar {bool} --
            display a progressbar (default: {True})

        Example:
            >>> directory = 'data'
            ... detected_faces = detect_from_directory(directory)
            {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}

        """
        if self.verbose:
            logger = logging.getLogger(__name__)

        if len(extensions) == 0:
            if self.verbose:
                logger.error("Expected at least one extension, but none was received.")
            raise ValueError

        if self.verbose:
            logger.info("Constructing the list of images.")
        additional_pattern = '/**/*' if recursive else '/*'
        files = []
        for extension in extensions:
            files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))

        if self.verbose:
            logger.info("Finished searching for images. %s images found", len(files))
            logger.info("Preparing to run the detection.")

        predictions = {}
        for image_path in tqdm(files, disable=not show_progress_bar):
            if self.verbose:
                logger.info("Running the face detector on image: %s", image_path)
            predictions[image_path] = self.detect_from_image(image_path)

        if self.verbose:
            logger.info("The detector was successfully run on all %s images", len(files))

        return predictions

    @property
    def reference_scale(self):
        raise NotImplementedError

    @property
    def reference_x_shift(self):
        raise NotImplementedError

    @property
    def reference_y_shift(self):
        raise NotImplementedError

    @staticmethod
    def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
        """Convert path (represented as a string) or torch.tensor to a numpy.ndarray

        Arguments:
            tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
        """
        if isinstance(tensor_or_path, str):
            return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
        elif torch.is_tensor(tensor_or_path):
            # Call cpu in case its coming from cuda
            return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
        elif isinstance(tensor_or_path, np.ndarray):
            return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
        else:
            raise TypeError
face_detection/detection/sfd/__init__.py
ADDED
@@ -0,0 +1 @@
from .sfd_detector import SFDDetector as FaceDetector
face_detection/detection/sfd/bbox.py
ADDED
@@ -0,0 +1,129 @@
from __future__ import print_function
import os
import sys
import cv2
import random
import datetime
import time
import math
import argparse
import numpy as np
import torch

try:
    from iou import IOU
except BaseException:
    # IOU cython speedup 10x
    def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
        sa = abs((ax2 - ax1) * (ay2 - ay1))
        sb = abs((bx2 - bx1) * (by2 - by1))
        x1, y1 = max(ax1, bx1), max(ay1, by1)
        x2, y2 = min(ax2, bx2), min(ay2, by2)
        w = x2 - x1
        h = y2 - y1
        if w < 0 or h < 0:
            return 0.0
        else:
            return 1.0 * w * h / (sa + sb - w * h)


def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
    xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
    dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
    dw, dh = math.log(ww / aww), math.log(hh / ahh)
    return dx, dy, dw, dh


def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
    xc, yc = dx * aww + axc, dy * ahh + ayc
    ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
    x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
    return x1, y1, x2, y2


def nms(dets, thresh):
    if 0 == len(dets):
        return []
    x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
        xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])

        w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
        ovr = w * h / (areas[i] + areas[order[1:]] - w * h)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep


def encode(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth boxes
    we have matched (based on jaccard overlap) with the prior boxes.
    Args:
        matched: (tensor) Coords of ground truth for each prior in point-form
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """

    # dist b/t match center and prior's center
    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, 2:])
    # match wh / prior wh
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    # return target for smooth_l1_loss
    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]


def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes

def batch_decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
        priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
    boxes[:, :, :2] -= boxes[:, :, 2:] / 2
    boxes[:, :, 2:] += boxes[:, :, :2]
    return boxes
|
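nms() above is plain greedy non-maximum suppression: keep the highest-scoring box, discard every remaining box whose overlap with it exceeds the threshold, and repeat on what is left. A tiny worked example with made-up detections (the second box is a near-duplicate of the first and gets suppressed):

import numpy as np
from face_detection.detection.sfd.bbox import nms

# Rows are (x1, y1, x2, y2, score).
dets = np.array([[ 10.,  10., 110., 110., 0.90],
                 [ 12.,  12., 112., 112., 0.80],
                 [200., 200., 300., 300., 0.75]])

keep = nms(dets, thresh=0.5)
print(keep)             # [0, 2] -> indices of the surviving boxes
print(dets[keep, :4])   # their coordinates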
face_detection/detection/sfd/detect.py
ADDED
@@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F

import os
import sys
import cv2
import random
import datetime
import math
import argparse
import numpy as np

import scipy.io as sio
import zipfile
from .net_s3fd import s3fd
from .bbox import *


def detect(net, img, device):
    img = img - np.array([104, 117, 123])
    img = img.transpose(2, 0, 1)
    img = img.reshape((1,) + img.shape)

    if 'cuda' in device:
        torch.backends.cudnn.benchmark = True

    img = torch.from_numpy(img).float().to(device)
    BB, CC, HH, WW = img.size()
    with torch.no_grad():
        olist = net(img)

    bboxlist = []
    for i in range(len(olist) // 2):
        olist[i * 2] = F.softmax(olist[i * 2], dim=1)
    olist = [oelem.data.cpu() for oelem in olist]
    for i in range(len(olist) // 2):
        ocls, oreg = olist[i * 2], olist[i * 2 + 1]
        FB, FC, FH, FW = ocls.size()  # feature map size
        stride = 2**(i + 2)  # 4,8,16,32,64,128
        anchor = stride * 4
        poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
        for Iindex, hindex, windex in poss:
            axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
            score = ocls[0, 1, hindex, windex]
            loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
            priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
            variances = [0.1, 0.2]
            box = decode(loc, priors, variances)
            x1, y1, x2, y2 = box[0] * 1.0
            # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
            bboxlist.append([x1, y1, x2, y2, score])
    bboxlist = np.array(bboxlist)
    if 0 == len(bboxlist):
        bboxlist = np.zeros((1, 5))

    return bboxlist

def batch_detect(net, imgs, device):
    imgs = imgs - np.array([104, 117, 123])
    imgs = imgs.transpose(0, 3, 1, 2)

    if 'cuda' in device:
        torch.backends.cudnn.benchmark = True

    imgs = torch.from_numpy(imgs).float().to(device)
    BB, CC, HH, WW = imgs.size()
    with torch.no_grad():
        olist = net(imgs)

    bboxlist = []
    for i in range(len(olist) // 2):
        olist[i * 2] = F.softmax(olist[i * 2], dim=1)
    olist = [oelem.data.cpu() for oelem in olist]
    for i in range(len(olist) // 2):
        ocls, oreg = olist[i * 2], olist[i * 2 + 1]
        FB, FC, FH, FW = ocls.size()  # feature map size
        stride = 2**(i + 2)  # 4,8,16,32,64,128
        anchor = stride * 4
        poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
        for Iindex, hindex, windex in poss:
            axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
            score = ocls[:, 1, hindex, windex]
            loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
            priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
            variances = [0.1, 0.2]
            box = batch_decode(loc, priors, variances)
            box = box[:, 0] * 1.0
            # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
            bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
    bboxlist = np.array(bboxlist)
    if 0 == len(bboxlist):
        bboxlist = np.zeros((1, BB, 5))

    return bboxlist

def flip_detect(net, img, device):
    img = cv2.flip(img, 1)
    b = detect(net, img, device)

    bboxlist = np.zeros(b.shape)
    bboxlist[:, 0] = img.shape[1] - b[:, 2]
    bboxlist[:, 1] = b[:, 1]
    bboxlist[:, 2] = img.shape[1] - b[:, 0]
    bboxlist[:, 3] = b[:, 3]
    bboxlist[:, 4] = b[:, 4]
    return bboxlist


def pts_to_bb(pts):
    min_x, min_y = np.min(pts, axis=0)
    max_x, max_y = np.max(pts, axis=0)
    return np.array([min_x, min_y, max_x, max_y])
|
face_detection/detection/sfd/net_s3fd.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class L2Norm(nn.Module):
|
7 |
+
def __init__(self, n_channels, scale=1.0):
|
8 |
+
super(L2Norm, self).__init__()
|
9 |
+
self.n_channels = n_channels
|
10 |
+
self.scale = scale
|
11 |
+
self.eps = 1e-10
|
12 |
+
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
13 |
+
self.weight.data *= 0.0
|
14 |
+
self.weight.data += self.scale
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
18 |
+
x = x / norm * self.weight.view(1, -1, 1, 1)
|
19 |
+
return x
|
20 |
+
|
21 |
+
|
22 |
+
class s3fd(nn.Module):
|
23 |
+
def __init__(self):
|
24 |
+
super(s3fd, self).__init__()
|
25 |
+
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
26 |
+
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
27 |
+
|
28 |
+
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
29 |
+
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
30 |
+
|
31 |
+
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
32 |
+
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
33 |
+
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
34 |
+
|
35 |
+
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
36 |
+
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
37 |
+
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
38 |
+
|
39 |
+
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
40 |
+
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
41 |
+
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
42 |
+
|
43 |
+
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
44 |
+
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
45 |
+
|
46 |
+
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
47 |
+
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
48 |
+
|
49 |
+
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
50 |
+
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
51 |
+
|
52 |
+
self.conv3_3_norm = L2Norm(256, scale=10)
|
53 |
+
self.conv4_3_norm = L2Norm(512, scale=8)
|
54 |
+
self.conv5_3_norm = L2Norm(512, scale=5)
|
55 |
+
|
56 |
+
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
57 |
+
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
58 |
+
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
59 |
+
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
60 |
+
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
61 |
+
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
62 |
+
|
63 |
+
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
64 |
+
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
65 |
+
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
66 |
+
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
67 |
+
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
68 |
+
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
h = F.relu(self.conv1_1(x))
|
72 |
+
h = F.relu(self.conv1_2(h))
|
73 |
+
h = F.max_pool2d(h, 2, 2)
|
74 |
+
|
75 |
+
h = F.relu(self.conv2_1(h))
|
76 |
+
h = F.relu(self.conv2_2(h))
|
77 |
+
h = F.max_pool2d(h, 2, 2)
|
78 |
+
|
79 |
+
h = F.relu(self.conv3_1(h))
|
80 |
+
h = F.relu(self.conv3_2(h))
|
81 |
+
h = F.relu(self.conv3_3(h))
|
82 |
+
f3_3 = h
|
83 |
+
h = F.max_pool2d(h, 2, 2)
|
84 |
+
|
85 |
+
h = F.relu(self.conv4_1(h))
|
86 |
+
h = F.relu(self.conv4_2(h))
|
87 |
+
h = F.relu(self.conv4_3(h))
|
88 |
+
f4_3 = h
|
89 |
+
h = F.max_pool2d(h, 2, 2)
|
90 |
+
|
91 |
+
h = F.relu(self.conv5_1(h))
|
92 |
+
h = F.relu(self.conv5_2(h))
|
93 |
+
h = F.relu(self.conv5_3(h))
|
94 |
+
f5_3 = h
|
95 |
+
h = F.max_pool2d(h, 2, 2)
|
96 |
+
|
97 |
+
h = F.relu(self.fc6(h))
|
98 |
+
h = F.relu(self.fc7(h))
|
99 |
+
ffc7 = h
|
100 |
+
h = F.relu(self.conv6_1(h))
|
101 |
+
h = F.relu(self.conv6_2(h))
|
102 |
+
f6_2 = h
|
103 |
+
h = F.relu(self.conv7_1(h))
|
104 |
+
h = F.relu(self.conv7_2(h))
|
105 |
+
f7_2 = h
|
106 |
+
|
107 |
+
f3_3 = self.conv3_3_norm(f3_3)
|
108 |
+
f4_3 = self.conv4_3_norm(f4_3)
|
109 |
+
f5_3 = self.conv5_3_norm(f5_3)
|
110 |
+
|
111 |
+
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
112 |
+
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
113 |
+
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
114 |
+
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
115 |
+
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
116 |
+
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
117 |
+
cls4 = self.fc7_mbox_conf(ffc7)
|
118 |
+
reg4 = self.fc7_mbox_loc(ffc7)
|
119 |
+
cls5 = self.conv6_2_mbox_conf(f6_2)
|
120 |
+
reg5 = self.conv6_2_mbox_loc(f6_2)
|
121 |
+
cls6 = self.conv7_2_mbox_conf(f7_2)
|
122 |
+
reg6 = self.conv7_2_mbox_loc(f7_2)
|
123 |
+
|
124 |
+
# max-out background label
|
125 |
+
chunk = torch.chunk(cls1, 4, 1)
|
126 |
+
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
127 |
+
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
128 |
+
|
129 |
+
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
face_detection/detection/sfd/sfd_detector.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
from torch.utils.model_zoo import load_url
|
4 |
+
|
5 |
+
from ..core import FaceDetector
|
6 |
+
|
7 |
+
from .net_s3fd import s3fd
|
8 |
+
from .bbox import *
|
9 |
+
from .detect import *
|
10 |
+
|
11 |
+
models_urls = {
|
12 |
+
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
class SFDDetector(FaceDetector):
|
17 |
+
@classmethod
|
18 |
+
def load_model(cls, device):
|
19 |
+
path_to_detector = os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth')
|
20 |
+
|
21 |
+
# Initialise the face detector
|
22 |
+
if not os.path.isfile(path_to_detector):
|
23 |
+
model_weights = load_url(models_urls['s3fd'])
|
24 |
+
else:
|
25 |
+
model_weights = torch.load(path_to_detector)
|
26 |
+
|
27 |
+
cls.face_detector = s3fd()
|
28 |
+
cls.face_detector.load_state_dict(model_weights)
|
29 |
+
cls.face_detector.to(device)
|
30 |
+
cls.face_detector.eval()
|
31 |
+
|
32 |
+
def detect_from_image(self, tensor_or_path):
|
33 |
+
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
34 |
+
|
35 |
+
bboxlist = detect(self.face_detector, image, device=self.device)
|
36 |
+
keep = nms(bboxlist, 0.3)
|
37 |
+
bboxlist = bboxlist[keep, :]
|
38 |
+
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
39 |
+
|
40 |
+
return bboxlist
|
41 |
+
|
42 |
+
def detect_from_batch(self, images):
|
43 |
+
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
44 |
+
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
45 |
+
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
46 |
+
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
47 |
+
|
48 |
+
return bboxlists
|
49 |
+
|
50 |
+
@property
|
51 |
+
def reference_scale(self):
|
52 |
+
return 195
|
53 |
+
|
54 |
+
@property
|
55 |
+
def reference_x_shift(self):
|
56 |
+
return 0
|
57 |
+
|
58 |
+
@property
|
59 |
+
def reference_y_shift(self):
|
60 |
+
return 0
|
face_detection/models.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
8 |
+
"3x3 convolution with padding"
|
9 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
10 |
+
stride=strd, padding=padding, bias=bias)
|
11 |
+
|
12 |
+
|
13 |
+
class ConvBlock(nn.Module):
|
14 |
+
def __init__(self, in_planes, out_planes):
|
15 |
+
super(ConvBlock, self).__init__()
|
16 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
17 |
+
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
18 |
+
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
19 |
+
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
20 |
+
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
21 |
+
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
22 |
+
|
23 |
+
if in_planes != out_planes:
|
24 |
+
self.downsample = nn.Sequential(
|
25 |
+
nn.BatchNorm2d(in_planes),
|
26 |
+
nn.ReLU(True),
|
27 |
+
nn.Conv2d(in_planes, out_planes,
|
28 |
+
kernel_size=1, stride=1, bias=False),
|
29 |
+
)
|
30 |
+
else:
|
31 |
+
self.downsample = None
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
residual = x
|
35 |
+
|
36 |
+
out1 = self.bn1(x)
|
37 |
+
out1 = F.relu(out1, True)
|
38 |
+
out1 = self.conv1(out1)
|
39 |
+
|
40 |
+
out2 = self.bn2(out1)
|
41 |
+
out2 = F.relu(out2, True)
|
42 |
+
out2 = self.conv2(out2)
|
43 |
+
|
44 |
+
out3 = self.bn3(out2)
|
45 |
+
out3 = F.relu(out3, True)
|
46 |
+
out3 = self.conv3(out3)
|
47 |
+
|
48 |
+
out3 = torch.cat((out1, out2, out3), 1)
|
49 |
+
|
50 |
+
if self.downsample is not None:
|
51 |
+
residual = self.downsample(residual)
|
52 |
+
|
53 |
+
out3 += residual
|
54 |
+
|
55 |
+
return out3
|
56 |
+
|
57 |
+
|
58 |
+
class Bottleneck(nn.Module):
|
59 |
+
|
60 |
+
expansion = 4
|
61 |
+
|
62 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
63 |
+
super(Bottleneck, self).__init__()
|
64 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
65 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
66 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
67 |
+
padding=1, bias=False)
|
68 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
69 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
70 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
71 |
+
self.relu = nn.ReLU(inplace=True)
|
72 |
+
self.downsample = downsample
|
73 |
+
self.stride = stride
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
residual = x
|
77 |
+
|
78 |
+
out = self.conv1(x)
|
79 |
+
out = self.bn1(out)
|
80 |
+
out = self.relu(out)
|
81 |
+
|
82 |
+
out = self.conv2(out)
|
83 |
+
out = self.bn2(out)
|
84 |
+
out = self.relu(out)
|
85 |
+
|
86 |
+
out = self.conv3(out)
|
87 |
+
out = self.bn3(out)
|
88 |
+
|
89 |
+
if self.downsample is not None:
|
90 |
+
residual = self.downsample(x)
|
91 |
+
|
92 |
+
out += residual
|
93 |
+
out = self.relu(out)
|
94 |
+
|
95 |
+
return out
|
96 |
+
|
97 |
+
|
98 |
+
class HourGlass(nn.Module):
|
99 |
+
def __init__(self, num_modules, depth, num_features):
|
100 |
+
super(HourGlass, self).__init__()
|
101 |
+
self.num_modules = num_modules
|
102 |
+
self.depth = depth
|
103 |
+
self.features = num_features
|
104 |
+
|
105 |
+
self._generate_network(self.depth)
|
106 |
+
|
107 |
+
def _generate_network(self, level):
|
108 |
+
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
109 |
+
|
110 |
+
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
111 |
+
|
112 |
+
if level > 1:
|
113 |
+
self._generate_network(level - 1)
|
114 |
+
else:
|
115 |
+
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
116 |
+
|
117 |
+
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
118 |
+
|
119 |
+
def _forward(self, level, inp):
|
120 |
+
# Upper branch
|
121 |
+
up1 = inp
|
122 |
+
up1 = self._modules['b1_' + str(level)](up1)
|
123 |
+
|
124 |
+
# Lower branch
|
125 |
+
low1 = F.avg_pool2d(inp, 2, stride=2)
|
126 |
+
low1 = self._modules['b2_' + str(level)](low1)
|
127 |
+
|
128 |
+
if level > 1:
|
129 |
+
low2 = self._forward(level - 1, low1)
|
130 |
+
else:
|
131 |
+
low2 = low1
|
132 |
+
low2 = self._modules['b2_plus_' + str(level)](low2)
|
133 |
+
|
134 |
+
low3 = low2
|
135 |
+
low3 = self._modules['b3_' + str(level)](low3)
|
136 |
+
|
137 |
+
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
138 |
+
|
139 |
+
return up1 + up2
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
return self._forward(self.depth, x)
|
143 |
+
|
144 |
+
|
145 |
+
class FAN(nn.Module):
|
146 |
+
|
147 |
+
def __init__(self, num_modules=1):
|
148 |
+
super(FAN, self).__init__()
|
149 |
+
self.num_modules = num_modules
|
150 |
+
|
151 |
+
# Base part
|
152 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
153 |
+
self.bn1 = nn.BatchNorm2d(64)
|
154 |
+
self.conv2 = ConvBlock(64, 128)
|
155 |
+
self.conv3 = ConvBlock(128, 128)
|
156 |
+
self.conv4 = ConvBlock(128, 256)
|
157 |
+
|
158 |
+
# Stacking part
|
159 |
+
for hg_module in range(self.num_modules):
|
160 |
+
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
161 |
+
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
162 |
+
self.add_module('conv_last' + str(hg_module),
|
163 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
164 |
+
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
165 |
+
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
166 |
+
68, kernel_size=1, stride=1, padding=0))
|
167 |
+
|
168 |
+
if hg_module < self.num_modules - 1:
|
169 |
+
self.add_module(
|
170 |
+
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
171 |
+
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
172 |
+
256, kernel_size=1, stride=1, padding=0))
|
173 |
+
|
174 |
+
def forward(self, x):
|
175 |
+
x = F.relu(self.bn1(self.conv1(x)), True)
|
176 |
+
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
177 |
+
x = self.conv3(x)
|
178 |
+
x = self.conv4(x)
|
179 |
+
|
180 |
+
previous = x
|
181 |
+
|
182 |
+
outputs = []
|
183 |
+
for i in range(self.num_modules):
|
184 |
+
hg = self._modules['m' + str(i)](previous)
|
185 |
+
|
186 |
+
ll = hg
|
187 |
+
ll = self._modules['top_m_' + str(i)](ll)
|
188 |
+
|
189 |
+
ll = F.relu(self._modules['bn_end' + str(i)]
|
190 |
+
(self._modules['conv_last' + str(i)](ll)), True)
|
191 |
+
|
192 |
+
# Predict heatmaps
|
193 |
+
tmp_out = self._modules['l' + str(i)](ll)
|
194 |
+
outputs.append(tmp_out)
|
195 |
+
|
196 |
+
if i < self.num_modules - 1:
|
197 |
+
ll = self._modules['bl' + str(i)](ll)
|
198 |
+
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
199 |
+
previous = previous + ll + tmp_out_
|
200 |
+
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
class ResNetDepth(nn.Module):
|
205 |
+
|
206 |
+
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
207 |
+
self.inplanes = 64
|
208 |
+
super(ResNetDepth, self).__init__()
|
209 |
+
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
210 |
+
bias=False)
|
211 |
+
self.bn1 = nn.BatchNorm2d(64)
|
212 |
+
self.relu = nn.ReLU(inplace=True)
|
213 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
214 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
215 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
216 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
217 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
218 |
+
self.avgpool = nn.AvgPool2d(7)
|
219 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
220 |
+
|
221 |
+
for m in self.modules():
|
222 |
+
if isinstance(m, nn.Conv2d):
|
223 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
224 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
225 |
+
elif isinstance(m, nn.BatchNorm2d):
|
226 |
+
m.weight.data.fill_(1)
|
227 |
+
m.bias.data.zero_()
|
228 |
+
|
229 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
230 |
+
downsample = None
|
231 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
232 |
+
downsample = nn.Sequential(
|
233 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
234 |
+
kernel_size=1, stride=stride, bias=False),
|
235 |
+
nn.BatchNorm2d(planes * block.expansion),
|
236 |
+
)
|
237 |
+
|
238 |
+
layers = []
|
239 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
240 |
+
self.inplanes = planes * block.expansion
|
241 |
+
for i in range(1, blocks):
|
242 |
+
layers.append(block(self.inplanes, planes))
|
243 |
+
|
244 |
+
return nn.Sequential(*layers)
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
x = self.conv1(x)
|
248 |
+
x = self.bn1(x)
|
249 |
+
x = self.relu(x)
|
250 |
+
x = self.maxpool(x)
|
251 |
+
|
252 |
+
x = self.layer1(x)
|
253 |
+
x = self.layer2(x)
|
254 |
+
x = self.layer3(x)
|
255 |
+
x = self.layer4(x)
|
256 |
+
|
257 |
+
x = self.avgpool(x)
|
258 |
+
x = x.view(x.size(0), -1)
|
259 |
+
x = self.fc(x)
|
260 |
+
|
261 |
+
return x
|
face_detection/utils.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
|
11 |
+
def _gaussian(
|
12 |
+
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
13 |
+
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
14 |
+
mean_vert=0.5):
|
15 |
+
# handle some defaults
|
16 |
+
if width is None:
|
17 |
+
width = size
|
18 |
+
if height is None:
|
19 |
+
height = size
|
20 |
+
if sigma_horz is None:
|
21 |
+
sigma_horz = sigma
|
22 |
+
if sigma_vert is None:
|
23 |
+
sigma_vert = sigma
|
24 |
+
center_x = mean_horz * width + 0.5
|
25 |
+
center_y = mean_vert * height + 0.5
|
26 |
+
gauss = np.empty((height, width), dtype=np.float32)
|
27 |
+
# generate kernel
|
28 |
+
for i in range(height):
|
29 |
+
for j in range(width):
|
30 |
+
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
31 |
+
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
32 |
+
if normalize:
|
33 |
+
gauss = gauss / np.sum(gauss)
|
34 |
+
return gauss
|
35 |
+
|
36 |
+
|
37 |
+
def draw_gaussian(image, point, sigma):
|
38 |
+
# Check if the gaussian is inside
|
39 |
+
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
40 |
+
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
41 |
+
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
42 |
+
return image
|
43 |
+
size = 6 * sigma + 1
|
44 |
+
g = _gaussian(size)
|
45 |
+
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
46 |
+
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
47 |
+
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
48 |
+
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
49 |
+
assert (g_x[0] > 0 and g_y[1] > 0)
|
50 |
+
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
51 |
+
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
52 |
+
image[image > 1] = 1
|
53 |
+
return image
|
54 |
+
|
55 |
+
|
56 |
+
def transform(point, center, scale, resolution, invert=False):
|
57 |
+
"""Generate and affine transformation matrix.
|
58 |
+
|
59 |
+
Given a set of points, a center, a scale and a targer resolution, the
|
60 |
+
function generates and affine transformation matrix. If invert is ``True``
|
61 |
+
it will produce the inverse transformation.
|
62 |
+
|
63 |
+
Arguments:
|
64 |
+
point {torch.tensor} -- the input 2D point
|
65 |
+
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
66 |
+
scale {float} -- the scale of the face/object
|
67 |
+
resolution {float} -- the output resolution
|
68 |
+
|
69 |
+
Keyword Arguments:
|
70 |
+
invert {bool} -- define wherever the function should produce the direct or the
|
71 |
+
inverse transformation matrix (default: {False})
|
72 |
+
"""
|
73 |
+
_pt = torch.ones(3)
|
74 |
+
_pt[0] = point[0]
|
75 |
+
_pt[1] = point[1]
|
76 |
+
|
77 |
+
h = 200.0 * scale
|
78 |
+
t = torch.eye(3)
|
79 |
+
t[0, 0] = resolution / h
|
80 |
+
t[1, 1] = resolution / h
|
81 |
+
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
82 |
+
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
83 |
+
|
84 |
+
if invert:
|
85 |
+
t = torch.inverse(t)
|
86 |
+
|
87 |
+
new_point = (torch.matmul(t, _pt))[0:2]
|
88 |
+
|
89 |
+
return new_point.int()
|
90 |
+
|
91 |
+
|
92 |
+
def crop(image, center, scale, resolution=256.0):
|
93 |
+
"""Center crops an image or set of heatmaps
|
94 |
+
|
95 |
+
Arguments:
|
96 |
+
image {numpy.array} -- an rgb image
|
97 |
+
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
98 |
+
scale {float} -- scale of the face
|
99 |
+
|
100 |
+
Keyword Arguments:
|
101 |
+
resolution {float} -- the size of the output cropped image (default: {256.0})
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
[type] -- [description]
|
105 |
+
""" # Crop around the center point
|
106 |
+
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
107 |
+
ul = transform([1, 1], center, scale, resolution, True)
|
108 |
+
br = transform([resolution, resolution], center, scale, resolution, True)
|
109 |
+
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
110 |
+
if image.ndim > 2:
|
111 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
112 |
+
image.shape[2]], dtype=np.int32)
|
113 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
114 |
+
else:
|
115 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
116 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
117 |
+
ht = image.shape[0]
|
118 |
+
wd = image.shape[1]
|
119 |
+
newX = np.array(
|
120 |
+
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
121 |
+
newY = np.array(
|
122 |
+
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
123 |
+
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
124 |
+
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
125 |
+
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
126 |
+
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
127 |
+
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
128 |
+
interpolation=cv2.INTER_LINEAR)
|
129 |
+
return newImg
|
130 |
+
|
131 |
+
|
132 |
+
def get_preds_fromhm(hm, center=None, scale=None):
|
133 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
134 |
+
and the scale is provided the function will return the points also in
|
135 |
+
the original coordinate frame.
|
136 |
+
|
137 |
+
Arguments:
|
138 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
139 |
+
|
140 |
+
Keyword Arguments:
|
141 |
+
center {torch.tensor} -- the center of the bounding box (default: {None})
|
142 |
+
scale {float} -- face scale (default: {None})
|
143 |
+
"""
|
144 |
+
max, idx = torch.max(
|
145 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
146 |
+
idx += 1
|
147 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
148 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
149 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
150 |
+
|
151 |
+
for i in range(preds.size(0)):
|
152 |
+
for j in range(preds.size(1)):
|
153 |
+
hm_ = hm[i, j, :]
|
154 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
155 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
156 |
+
diff = torch.FloatTensor(
|
157 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
158 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
159 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
160 |
+
|
161 |
+
preds.add_(-.5)
|
162 |
+
|
163 |
+
preds_orig = torch.zeros(preds.size())
|
164 |
+
if center is not None and scale is not None:
|
165 |
+
for i in range(hm.size(0)):
|
166 |
+
for j in range(hm.size(1)):
|
167 |
+
preds_orig[i, j] = transform(
|
168 |
+
preds[i, j], center, scale, hm.size(2), True)
|
169 |
+
|
170 |
+
return preds, preds_orig
|
171 |
+
|
172 |
+
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
173 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
174 |
+
and the scales is provided the function will return the points also in
|
175 |
+
the original coordinate frame.
|
176 |
+
|
177 |
+
Arguments:
|
178 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
179 |
+
|
180 |
+
Keyword Arguments:
|
181 |
+
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
182 |
+
scales {float} -- face scales (default: {None})
|
183 |
+
"""
|
184 |
+
max, idx = torch.max(
|
185 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
186 |
+
idx += 1
|
187 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
188 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
189 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
190 |
+
|
191 |
+
for i in range(preds.size(0)):
|
192 |
+
for j in range(preds.size(1)):
|
193 |
+
hm_ = hm[i, j, :]
|
194 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
195 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
196 |
+
diff = torch.FloatTensor(
|
197 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
198 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
199 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
200 |
+
|
201 |
+
preds.add_(-.5)
|
202 |
+
|
203 |
+
preds_orig = torch.zeros(preds.size())
|
204 |
+
if centers is not None and scales is not None:
|
205 |
+
for i in range(hm.size(0)):
|
206 |
+
for j in range(hm.size(1)):
|
207 |
+
preds_orig[i, j] = transform(
|
208 |
+
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
209 |
+
|
210 |
+
return preds, preds_orig
|
211 |
+
|
212 |
+
def shuffle_lr(parts, pairs=None):
|
213 |
+
"""Shuffle the points left-right according to the axis of symmetry
|
214 |
+
of the object.
|
215 |
+
|
216 |
+
Arguments:
|
217 |
+
parts {torch.tensor} -- a 3D or 4D object containing the
|
218 |
+
heatmaps.
|
219 |
+
|
220 |
+
Keyword Arguments:
|
221 |
+
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
222 |
+
"""
|
223 |
+
if pairs is None:
|
224 |
+
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
225 |
+
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
226 |
+
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
227 |
+
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
228 |
+
62, 61, 60, 67, 66, 65]
|
229 |
+
if parts.ndimension() == 3:
|
230 |
+
parts = parts[pairs, ...]
|
231 |
+
else:
|
232 |
+
parts = parts[:, pairs, ...]
|
233 |
+
|
234 |
+
return parts
|
235 |
+
|
236 |
+
|
237 |
+
def flip(tensor, is_label=False):
|
238 |
+
"""Flip an image or a set of heatmaps left-right
|
239 |
+
|
240 |
+
Arguments:
|
241 |
+
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
242 |
+
|
243 |
+
Keyword Arguments:
|
244 |
+
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
245 |
+
"""
|
246 |
+
if not torch.is_tensor(tensor):
|
247 |
+
tensor = torch.from_numpy(tensor)
|
248 |
+
|
249 |
+
if is_label:
|
250 |
+
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
251 |
+
else:
|
252 |
+
tensor = tensor.flip(tensor.ndimension() - 1)
|
253 |
+
|
254 |
+
return tensor
|
255 |
+
|
256 |
+
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
257 |
+
|
258 |
+
|
259 |
+
def appdata_dir(appname=None, roaming=False):
|
260 |
+
""" appdata_dir(appname=None, roaming=False)
|
261 |
+
|
262 |
+
Get the path to the application directory, where applications are allowed
|
263 |
+
to write user specific files (e.g. configurations). For non-user specific
|
264 |
+
data, consider using common_appdata_dir().
|
265 |
+
If appname is given, a subdir is appended (and created if necessary).
|
266 |
+
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
267 |
+
"""
|
268 |
+
|
269 |
+
# Define default user directory
|
270 |
+
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
271 |
+
if userDir is None:
|
272 |
+
userDir = os.path.expanduser('~')
|
273 |
+
if not os.path.isdir(userDir): # pragma: no cover
|
274 |
+
userDir = '/var/tmp' # issue #54
|
275 |
+
|
276 |
+
# Get system app data dir
|
277 |
+
path = None
|
278 |
+
if sys.platform.startswith('win'):
|
279 |
+
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
280 |
+
path = (path2 or path1) if roaming else (path1 or path2)
|
281 |
+
elif sys.platform.startswith('darwin'):
|
282 |
+
path = os.path.join(userDir, 'Library', 'Application Support')
|
283 |
+
# On Linux and as fallback
|
284 |
+
if not (path and os.path.isdir(path)):
|
285 |
+
path = userDir
|
286 |
+
|
287 |
+
# Maybe we should store things local to the executable (in case of a
|
288 |
+
# portable distro or a frozen application that wants to be portable)
|
289 |
+
prefix = sys.prefix
|
290 |
+
if getattr(sys, 'frozen', None):
|
291 |
+
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
292 |
+
for reldir in ('settings', '../settings'):
|
293 |
+
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
294 |
+
if os.path.isdir(localpath): # pragma: no cover
|
295 |
+
try:
|
296 |
+
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
297 |
+
os.remove(os.path.join(localpath, 'test.write'))
|
298 |
+
except IOError:
|
299 |
+
pass # We cannot write in this directory
|
300 |
+
else:
|
301 |
+
path = localpath
|
302 |
+
break
|
303 |
+
|
304 |
+
# Get path specific for this app
|
305 |
+
if appname:
|
306 |
+
if path == userDir:
|
307 |
+
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
308 |
+
path = os.path.join(path, appname)
|
309 |
+
if not os.path.isdir(path): # pragma: no cover
|
310 |
+
os.mkdir(path)
|
311 |
+
|
312 |
+
# Done
|
313 |
+
return path
|
filelists/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Place LRS2 (and any other) filelists here for training.
|
hparams.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from glob import glob
|
2 |
+
import os
|
3 |
+
|
4 |
+
def get_image_list(data_root, split):
|
5 |
+
filelist = []
|
6 |
+
|
7 |
+
with open('filelists/{}.txt'.format(split)) as f:
|
8 |
+
for line in f:
|
9 |
+
line = line.strip()
|
10 |
+
if ' ' in line: line = line.split()[0]
|
11 |
+
filelist.append(os.path.join(data_root, line))
|
12 |
+
|
13 |
+
return filelist
|
14 |
+
|
15 |
+
class HParams:
|
16 |
+
def __init__(self, **kwargs):
|
17 |
+
self.data = {}
|
18 |
+
|
19 |
+
for key, value in kwargs.items():
|
20 |
+
self.data[key] = value
|
21 |
+
|
22 |
+
def __getattr__(self, key):
|
23 |
+
if key not in self.data:
|
24 |
+
raise AttributeError("'HParams' object has no attribute %s" % key)
|
25 |
+
return self.data[key]
|
26 |
+
|
27 |
+
def set_hparam(self, key, value):
|
28 |
+
self.data[key] = value
|
29 |
+
|
30 |
+
|
31 |
+
# Default hyperparameters
|
32 |
+
hparams = HParams(
|
33 |
+
num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
|
34 |
+
# network
|
35 |
+
rescale=True, # Whether to rescale audio prior to preprocessing
|
36 |
+
rescaling_max=0.9, # Rescaling value
|
37 |
+
|
38 |
+
# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
|
39 |
+
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
|
40 |
+
# Does not work if n_ffit is not multiple of hop_size!!
|
41 |
+
use_lws=False,
|
42 |
+
|
43 |
+
n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
|
44 |
+
hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
|
45 |
+
win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
|
46 |
+
sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
|
47 |
+
|
48 |
+
frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
|
49 |
+
|
50 |
+
# Mel and Linear spectrograms normalization/scaling and clipping
|
51 |
+
signal_normalization=True,
|
52 |
+
# Whether to normalize mel spectrograms to some predefined range (following below parameters)
|
53 |
+
allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
|
54 |
+
symmetric_mels=True,
|
55 |
+
# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
|
56 |
+
# faster and cleaner convergence)
|
57 |
+
max_abs_value=4.,
|
58 |
+
# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
|
59 |
+
# be too big to avoid gradient explosion,
|
60 |
+
# not too small for fast convergence)
|
61 |
+
# Contribution by @begeekmyfriend
|
62 |
+
# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
|
63 |
+
# levels. Also allows for better G&L phase reconstruction)
|
64 |
+
preemphasize=True, # whether to apply filter
|
65 |
+
preemphasis=0.97, # filter coefficient.
|
66 |
+
|
67 |
+
# Limits
|
68 |
+
min_level_db=-100,
|
69 |
+
ref_level_db=20,
|
70 |
+
fmin=55,
|
71 |
+
# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
|
72 |
+
# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
73 |
+
fmax=7600, # To be increased/reduced depending on data.
|
74 |
+
|
75 |
+
###################### Our training parameters #################################
|
76 |
+
img_size=96,
|
77 |
+
fps=25,
|
78 |
+
|
79 |
+
batch_size=16,
|
80 |
+
initial_learning_rate=1e-4,
|
81 |
+
nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
|
82 |
+
num_workers=16,
|
83 |
+
checkpoint_interval=3000,
|
84 |
+
eval_interval=3000,
|
85 |
+
save_optimizer_state=True,
|
86 |
+
|
87 |
+
syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
|
88 |
+
syncnet_batch_size=64,
|
89 |
+
syncnet_lr=1e-4,
|
90 |
+
syncnet_eval_interval=10000,
|
91 |
+
syncnet_checkpoint_interval=10000,
|
92 |
+
|
93 |
+
disc_wt=0.07,
|
94 |
+
disc_initial_learning_rate=1e-4,
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
def hparams_debug_string():
|
99 |
+
values = hparams.values()
|
100 |
+
hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
|
101 |
+
return "Hyperparameters:\n" + "\n".join(hp)
|
hq_wav2lip_train.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import dirname, join, basename, isfile
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
from models import SyncNet_color as SyncNet
|
5 |
+
from models import Wav2Lip, Wav2Lip_disc_qual
|
6 |
+
import audio
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from torch import optim
|
12 |
+
import torch.backends.cudnn as cudnn
|
13 |
+
from torch.utils import data as data_utils
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
from glob import glob
|
17 |
+
|
18 |
+
import os, random, cv2, argparse
|
19 |
+
from hparams import hparams, get_image_list
|
20 |
+
|
21 |
+
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')
|
22 |
+
|
23 |
+
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
|
24 |
+
|
25 |
+
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
|
26 |
+
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
|
27 |
+
|
28 |
+
parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
|
29 |
+
parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)
|
30 |
+
|
31 |
+
args = parser.parse_args()
|
32 |
+
|
33 |
+
|
34 |
+
global_step = 0
|
35 |
+
global_epoch = 0
|
36 |
+
use_cuda = torch.cuda.is_available()
|
37 |
+
print('use_cuda: {}'.format(use_cuda))
|
38 |
+
|
39 |
+
syncnet_T = 5
|
40 |
+
syncnet_mel_step_size = 16
|
41 |
+
|
42 |
+
class Dataset(object):
|
43 |
+
def __init__(self, split):
|
44 |
+
self.all_videos = get_image_list(args.data_root, split)
|
45 |
+
|
46 |
+
def get_frame_id(self, frame):
|
47 |
+
return int(basename(frame).split('.')[0])
|
48 |
+
|
49 |
+
def get_window(self, start_frame):
|
50 |
+
start_id = self.get_frame_id(start_frame)
|
51 |
+
vidname = dirname(start_frame)
|
52 |
+
|
53 |
+
window_fnames = []
|
54 |
+
for frame_id in range(start_id, start_id + syncnet_T):
|
55 |
+
frame = join(vidname, '{}.jpg'.format(frame_id))
|
56 |
+
if not isfile(frame):
|
57 |
+
return None
|
58 |
+
window_fnames.append(frame)
|
59 |
+
return window_fnames
|
60 |
+
|
61 |
+
def read_window(self, window_fnames):
|
62 |
+
if window_fnames is None: return None
|
63 |
+
window = []
|
64 |
+
for fname in window_fnames:
|
65 |
+
img = cv2.imread(fname)
|
66 |
+
if img is None:
|
67 |
+
return None
|
68 |
+
try:
|
69 |
+
img = cv2.resize(img, (hparams.img_size, hparams.img_size))
|
70 |
+
except Exception as e:
|
71 |
+
return None
|
72 |
+
|
73 |
+
window.append(img)
|
74 |
+
|
75 |
+
return window
|
76 |
+
|
77 |
+
def crop_audio_window(self, spec, start_frame):
|
78 |
+
if type(start_frame) == int:
|
79 |
+
start_frame_num = start_frame
|
80 |
+
else:
|
81 |
+
start_frame_num = self.get_frame_id(start_frame)
|
82 |
+
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
|
83 |
+
|
84 |
+
end_idx = start_idx + syncnet_mel_step_size
|
85 |
+
|
86 |
+
return spec[start_idx : end_idx, :]
|
87 |
+
|
88 |
+
def get_segmented_mels(self, spec, start_frame):
|
89 |
+
mels = []
|
90 |
+
assert syncnet_T == 5
|
91 |
+
start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
|
92 |
+
if start_frame_num - 2 < 0: return None
|
93 |
+
for i in range(start_frame_num, start_frame_num + syncnet_T):
|
94 |
+
m = self.crop_audio_window(spec, i - 2)
|
95 |
+
if m.shape[0] != syncnet_mel_step_size:
|
96 |
+
return None
|
97 |
+
mels.append(m.T)
|
98 |
+
|
99 |
+
mels = np.asarray(mels)
|
100 |
+
|
101 |
+
return mels
|
102 |
+
|
103 |
+
def prepare_window(self, window):
|
104 |
+
# 3 x T x H x W
|
105 |
+
x = np.asarray(window) / 255.
|
106 |
+
x = np.transpose(x, (3, 0, 1, 2))
|
107 |
+
|
108 |
+
return x
|
109 |
+
|
110 |
+
def __len__(self):
|
111 |
+
return len(self.all_videos)
|
112 |
+
|
113 |
+
def __getitem__(self, idx):
|
114 |
+
while 1:
|
115 |
+
idx = random.randint(0, len(self.all_videos) - 1)
|
116 |
+
vidname = self.all_videos[idx]
|
117 |
+
img_names = list(glob(join(vidname, '*.jpg')))
|
118 |
+
if len(img_names) <= 3 * syncnet_T:
|
119 |
+
continue
|
120 |
+
|
121 |
+
img_name = random.choice(img_names)
|
122 |
+
wrong_img_name = random.choice(img_names)
|
123 |
+
while wrong_img_name == img_name:
|
124 |
+
wrong_img_name = random.choice(img_names)
|
125 |
+
|
126 |
+
window_fnames = self.get_window(img_name)
|
127 |
+
wrong_window_fnames = self.get_window(wrong_img_name)
|
128 |
+
if window_fnames is None or wrong_window_fnames is None:
|
129 |
+
continue
|
130 |
+
|
131 |
+
window = self.read_window(window_fnames)
|
132 |
+
if window is None:
|
133 |
+
continue
|
134 |
+
|
135 |
+
wrong_window = self.read_window(wrong_window_fnames)
|
136 |
+
if wrong_window is None:
|
137 |
+
continue
|
138 |
+
|
139 |
+
try:
|
140 |
+
wavpath = join(vidname, "audio.wav")
|
141 |
+
wav = audio.load_wav(wavpath, hparams.sample_rate)
|
142 |
+
|
143 |
+
orig_mel = audio.melspectrogram(wav).T
|
144 |
+
except Exception as e:
|
145 |
+
continue
|
146 |
+
|
147 |
+
mel = self.crop_audio_window(orig_mel.copy(), img_name)
|
148 |
+
|
149 |
+
if (mel.shape[0] != syncnet_mel_step_size):
|
150 |
+
continue
|
151 |
+
|
152 |
+
indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
|
153 |
+
if indiv_mels is None: continue
|
154 |
+
|
155 |
+
window = self.prepare_window(window)
|
156 |
+
y = window.copy()
|
157 |
+
window[:, :, window.shape[2]//2:] = 0.
|
158 |
+
|
159 |
+
wrong_window = self.prepare_window(wrong_window)
|
160 |
+
x = np.concatenate([window, wrong_window], axis=0)
|
161 |
+
|
162 |
+
x = torch.FloatTensor(x)
|
163 |
+
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
164 |
+
indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
|
165 |
+
y = torch.FloatTensor(y)
|
166 |
+
return x, indiv_mels, mel, y
|
167 |
+
|
168 |
+
def save_sample_images(x, g, gt, global_step, checkpoint_dir):
|
169 |
+
x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
170 |
+
g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
171 |
+
gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
172 |
+
|
173 |
+
refs, inps = x[..., 3:], x[..., :3]
|
174 |
+
folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
|
175 |
+
if not os.path.exists(folder): os.mkdir(folder)
|
176 |
+
collage = np.concatenate((refs, inps, g, gt), axis=-2)
|
177 |
+
for batch_idx, c in enumerate(collage):
|
178 |
+
for t in range(len(c)):
|
179 |
+
cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
|
180 |
+
|
181 |
+
logloss = nn.BCELoss()
|
182 |
+
def cosine_loss(a, v, y):
|
183 |
+
d = nn.functional.cosine_similarity(a, v)
|
184 |
+
loss = logloss(d.unsqueeze(1), y)
|
185 |
+
|
186 |
+
return loss
|
187 |
+
|
188 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
189 |
+
syncnet = SyncNet().to(device)
|
190 |
+
for p in syncnet.parameters():
|
191 |
+
p.requires_grad = False
|
192 |
+
|
193 |
+
recon_loss = nn.L1Loss()
|
194 |
+
def get_sync_loss(mel, g):
|
195 |
+
g = g[:, :, :, g.size(3)//2:]
|
196 |
+
g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
|
197 |
+
# B, 3 * T, H//2, W
|
198 |
+
a, v = syncnet(mel, g)
|
199 |
+
y = torch.ones(g.size(0), 1).float().to(device)
|
200 |
+
return cosine_loss(a, v, y)
|
201 |
+
|
202 |
+
def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
|
203 |
+
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
|
204 |
+
global global_step, global_epoch
|
205 |
+
resumed_step = global_step
|
206 |
+
|
207 |
+
while global_epoch < nepochs:
|
208 |
+
print('Starting Epoch: {}'.format(global_epoch))
|
209 |
+
running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
|
210 |
+
running_disc_real_loss, running_disc_fake_loss = 0., 0.
|
211 |
+
prog_bar = tqdm(enumerate(train_data_loader))
|
212 |
+
for step, (x, indiv_mels, mel, gt) in prog_bar:
|
213 |
+
disc.train()
|
214 |
+
model.train()
|
215 |
+
|
216 |
+
x = x.to(device)
|
217 |
+
mel = mel.to(device)
|
218 |
+
indiv_mels = indiv_mels.to(device)
|
219 |
+
gt = gt.to(device)
|
220 |
+
|
221 |
+
### Train generator now. Remove ALL grads.
|
222 |
+
optimizer.zero_grad()
|
223 |
+
disc_optimizer.zero_grad()
|
224 |
+
|
225 |
+
g = model(indiv_mels, x)
|
226 |
+
|
227 |
+
if hparams.syncnet_wt > 0.:
|
228 |
+
sync_loss = get_sync_loss(mel, g)
|
229 |
+
else:
|
230 |
+
sync_loss = 0.
|
231 |
+
|
232 |
+
if hparams.disc_wt > 0.:
|
233 |
+
perceptual_loss = disc.perceptual_forward(g)
|
234 |
+
else:
|
235 |
+
perceptual_loss = 0.
|
236 |
+
|
237 |
+
l1loss = recon_loss(g, gt)
|
238 |
+
|
239 |
+
loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
|
240 |
+
(1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
|
241 |
+
|
242 |
+
loss.backward()
|
243 |
+
optimizer.step()
|
244 |
+
|
245 |
+
### Remove all gradients before Training disc
|
246 |
+
disc_optimizer.zero_grad()
|
247 |
+
|
248 |
+
pred = disc(gt)
|
249 |
+
disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
|
250 |
+
disc_real_loss.backward()
|
251 |
+
|
252 |
+
pred = disc(g.detach())
|
253 |
+
disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
|
254 |
+
disc_fake_loss.backward()
|
255 |
+
|
256 |
+
disc_optimizer.step()
|
257 |
+
|
258 |
+
running_disc_real_loss += disc_real_loss.item()
|
259 |
+
running_disc_fake_loss += disc_fake_loss.item()
|
260 |
+
|
261 |
+
if global_step % checkpoint_interval == 0:
|
262 |
+
save_sample_images(x, g, gt, global_step, checkpoint_dir)
|
263 |
+
|
264 |
+
# Logs
|
265 |
+
global_step += 1
|
266 |
+
cur_session_steps = global_step - resumed_step
|
267 |
+
|
268 |
+
running_l1_loss += l1loss.item()
|
269 |
+
if hparams.syncnet_wt > 0.:
|
270 |
+
running_sync_loss += sync_loss.item()
|
271 |
+
else:
|
272 |
+
running_sync_loss += 0.
|
273 |
+
|
274 |
+
if hparams.disc_wt > 0.:
|
275 |
+
running_perceptual_loss += perceptual_loss.item()
|
276 |
+
else:
|
277 |
+
running_perceptual_loss += 0.
|
278 |
+
|
279 |
+
if global_step == 1 or global_step % checkpoint_interval == 0:
|
280 |
+
save_checkpoint(
|
281 |
+
model, optimizer, global_step, checkpoint_dir, global_epoch)
|
282 |
+
save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')
|
283 |
+
|
284 |
+
|
285 |
+
if global_step % hparams.eval_interval == 0:
|
286 |
+
with torch.no_grad():
|
287 |
+
average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
|
288 |
+
|
289 |
+
if average_sync_loss < .75:
|
290 |
+
hparams.set_hparam('syncnet_wt', 0.03)
|
291 |
+
|
292 |
+
prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
|
293 |
+
running_sync_loss / (step + 1),
|
294 |
+
running_perceptual_loss / (step + 1),
|
295 |
+
running_disc_fake_loss / (step + 1),
|
296 |
+
running_disc_real_loss / (step + 1)))
|
297 |
+
|
298 |
+
global_epoch += 1
|
299 |
+
|
300 |
+
def eval_model(test_data_loader, global_step, device, model, disc):
|
301 |
+
eval_steps = 300
|
302 |
+
print('Evaluating for {} steps'.format(eval_steps))
|
303 |
+
running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []
|
304 |
+
while 1:
|
305 |
+
for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)):
|
306 |
+
model.eval()
|
307 |
+
disc.eval()
|
308 |
+
|
309 |
+
x = x.to(device)
|
310 |
+
mel = mel.to(device)
|
311 |
+
indiv_mels = indiv_mels.to(device)
|
312 |
+
gt = gt.to(device)
|
313 |
+
|
314 |
+
pred = disc(gt)
|
315 |
+
disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
|
316 |
+
|
317 |
+
g = model(indiv_mels, x)
|
318 |
+
pred = disc(g)
|
319 |
+
disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
|
320 |
+
|
321 |
+
running_disc_real_loss.append(disc_real_loss.item())
|
322 |
+
running_disc_fake_loss.append(disc_fake_loss.item())
|
323 |
+
|
324 |
+
sync_loss = get_sync_loss(mel, g)
|
325 |
+
|
326 |
+
if hparams.disc_wt > 0.:
|
327 |
+
perceptual_loss = disc.perceptual_forward(g)
|
328 |
+
else:
|
329 |
+
perceptual_loss = 0.
|
330 |
+
|
331 |
+
l1loss = recon_loss(g, gt)
|
332 |
+
|
333 |
+
loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
|
334 |
+
(1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
|
335 |
+
|
336 |
+
running_l1_loss.append(l1loss.item())
|
337 |
+
running_sync_loss.append(sync_loss.item())
|
338 |
+
|
339 |
+
if hparams.disc_wt > 0.:
|
340 |
+
                    running_perceptual_loss.append(perceptual_loss.item())
                else:
                    running_perceptual_loss.append(0.)

                if step > eval_steps: break

        print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss),
                                                                         sum(running_sync_loss) / len(running_sync_loss),
                                                                         sum(running_perceptual_loss) / len(running_perceptual_loss),
                                                                         sum(running_disc_fake_loss) / len(running_disc_fake_loss),
                                                                         sum(running_disc_real_loss) / len(running_disc_real_loss)))
        return sum(running_sync_loss) / len(running_sync_loss)


def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
    checkpoint_path = join(
        checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

def _load(checkpoint_path):
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    if not reset_optimizer:
        optimizer_state = checkpoint["optimizer"]
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(checkpoint["optimizer"])
    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    return model

if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir

    # Dataset and Dataloader setup
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    disc = Wav2Lip_disc_qual().to(device)

    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
    disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
                                lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))

    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)

    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
                        reset_optimizer=False, overwrite_global_states=False)

    load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True,
                    overwrite_global_states=False)

    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # Train!
    train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
          checkpoint_dir=checkpoint_dir,
          checkpoint_interval=hparams.checkpoint_interval,
          nepochs=hparams.nepochs)
inference.py
ADDED
@@ -0,0 +1,323 @@
import argparse
import math
import os
import platform
import subprocess

import cv2
import numpy as np
import torch
from tqdm import tqdm

import audio
# from face_detect import face_rect
from models import Wav2Lip

from batch_face import RetinaFace
from time import time

parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')

parser.add_argument('--checkpoint_path', type=str,
                    help='Name of saved checkpoint to load weights from', required=True)

parser.add_argument('--face', type=str,
                    help='Filepath of video/image that contains faces to use', required=True)
parser.add_argument('--audio', type=str,
                    help='Filepath of video/audio file to use as raw audio source', required=True)
parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
                    default='results/result_voice.mp4')

parser.add_argument('--static', type=bool,
                    help='If True, then use only first video frame for inference', default=False)
parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
                    default=25., required=False)

parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')

parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)

parser.add_argument('--resize_factor', default=1, type=int,
                    help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')

parser.add_argument('--out_height', default=480, type=int,
                    help='Output video height. Best results are obtained at 480 or 720')

parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
                    help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
                         'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')

parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
                    help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
                         'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')

parser.add_argument('--rotate', default=False, action='store_true',
                    help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
                         'Use if you get a flipped result, despite feeding a normal looking video')

parser.add_argument('--nosmooth', default=False, action='store_true',
                    help='Prevent smoothing face detections over a short temporal window')


def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images):
    results = []
    pady1, pady2, padx1, padx2 = args.pads

    s = time()

    for image, rect in zip(images, face_rect(images)):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    print('face detect time:', time() - s)

    boxes = np.array(results)
    if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    return results


def datagen(frames, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if args.box[0] == -1:
        if not args.static:
            face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
        else:
            face_det_results = face_detect([frames[0]])
    else:
        print('Using the specified bounding box instead of face detection...')
        y1, y2, x1, x2 = args.box
        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    for i, m in enumerate(mels):
        idx = 0 if args.static else i%len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (args.img_size, args.img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, args.img_size//2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size//2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch

mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))

def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    return model.eval()

def main():
    args.img_size = 96

    if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
        args.static = True

    if not os.path.isfile(args.face):
        raise ValueError('--face argument must be a valid path to video/image file')

    elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
        full_frames = [cv2.imread(args.face)]
        fps = args.fps

    else:
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        print('Reading video frames...')

        full_frames = []
        while 1:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break

            aspect_ratio = frame.shape[1] / frame.shape[0]
            frame = cv2.resize(frame, (int(args.out_height * aspect_ratio), args.out_height))
            # if args.resize_factor > 1:
            #     frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))

            if args.rotate:
                frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = args.crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]

            full_frames.append(frame)

    print("Number of frames available for inference: "+str(len(full_frames)))

    if not args.audio.endswith('.wav'):
        print('Extracting raw audio...')
        # command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
        # subprocess.call(command, shell=True)
        subprocess.check_call([
            "ffmpeg", "-y",
            "-i", args.audio,
            "temp/temp.wav",
        ])
        args.audio = 'temp/temp.wav'

    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)
    print(mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

    mel_chunks = []
    mel_idx_multiplier = 80./fps
    i = 0
    while 1:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
        i += 1

    print("Length of mel chunks: {}".format(len(mel_chunks)))

    full_frames = full_frames[:len(mel_chunks)]

    batch_size = args.wav2lip_batch_size
    gen = datagen(full_frames.copy(), mel_chunks)

    s = time()

    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
                                            total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
        if i == 0:
            frame_h, frame_w = full_frames[0].shape[:-1]
            out = cv2.VideoWriter('temp/result.avi',
                                  cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

            f[y1:y2, x1:x2] = p
            out.write(f)

    out.release()

    print("wav2lip prediction time:", time() - s)

    subprocess.check_call([
        "ffmpeg", "-y",
        # "-vsync", "0", "-hwaccel", "cuda", "-hwaccel_output_format", "cuda",
        "-i", "temp/result.avi",
        "-i", args.audio,
        # "-c:v", "h264_nvenc",
        args.outfile,
    ])

model = detector = detector_model = None

def do_load(checkpoint_path):
    global model, detector, detector_model

    model = load_model(checkpoint_path)

    # SFDDetector.load_model(device)
    detector = RetinaFace(gpu_id=0, model_path="checkpoints/mobilenet.pth", network="mobilenet")
    # detector = RetinaFace(gpu_id=0, model_path="checkpoints/resnet50.pth", network="resnet50")

    detector_model = detector.model

    print("Models loaded")


face_batch_size = 64 * 8

def face_rect(images):
    num_batches = math.ceil(len(images) / face_batch_size)
    prev_ret = None
    for i in range(num_batches):
        batch = images[i * face_batch_size: (i + 1) * face_batch_size]
        all_faces = detector(batch)  # return faces list of all images
        for faces in all_faces:
            if faces:
                box, landmarks, score = faces[0]
                prev_ret = tuple(map(int, box))
            yield prev_ret


if __name__ == '__main__':
    args = parser.parse_args()
    do_load(args.checkpoint_path)
    main()
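As a usage sketch for the script above (the checkpoint path matches scripts/download_models.sh below; the face and audio file names are placeholders, and the script writes intermediate files under an existing temp/ directory):

python inference.py \
    --checkpoint_path checkpoints/wav2lip_gan.pth \
    --face path/to/face.mp4 \
    --audio path/to/speech.wav \
    --outfile results/result_voice.mp4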
models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
from .syncnet import SyncNet_color
models/conv.py
ADDED
@@ -0,0 +1,44 @@
import torch
from torch import nn
from torch.nn import functional as F

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)

class nonorm_Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
        )
        self.act = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)

class Conv2dTranspose(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)
models/syncnet.py
ADDED
@@ -0,0 +1,66 @@
import torch
from torch import nn
from torch.nn import functional as F

from .conv import Conv2d

class SyncNet_color(nn.Module):
    def __init__(self):
        super(SyncNet_color, self).__init__()

        self.face_encoder = nn.Sequential(
            Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),

            Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

    def forward(self, audio_sequences, face_sequences):  # audio_sequences := (B, dim, T)
        face_embedding = self.face_encoder(face_sequences)
        audio_embedding = self.audio_encoder(audio_sequences)

        audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
        face_embedding = face_embedding.view(face_embedding.size(0), -1)

        audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
        face_embedding = F.normalize(face_embedding, p=2, dim=1)

        return audio_embedding, face_embedding
models/wav2lip.py
ADDED
@@ -0,0 +1,184 @@
import torch
from torch import nn
from torch.nn import functional as F
import math

from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d

class Wav2Lip(nn.Module):
    def __init__(self):
        super(Wav2Lip, self).__init__()

        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)),  # 96,96

            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 48,48
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 24,24
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 12,12
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 6,6
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 3,3
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),

            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0),  # 1, 1
                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        self.face_decoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0),  # 3,3
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),  # 6, 6

            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),),  # 12, 12

            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),),  # 24, 24

            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),),  # 48, 48

            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),])  # 96,96

        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
                                          nn.Sigmoid())

    def forward(self, audio_sequences, face_sequences):
        # audio_sequences = (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1

        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e

            feats.pop()

        x = self.output_block(x)

        if input_dim_size > 4:
            x = torch.split(x, B, dim=0)  # [(B, C, H, W)]
            outputs = torch.stack(x, dim=2)  # (B, C, T, H, W)

        else:
            outputs = x

        return outputs

class Wav2Lip_disc_qual(nn.Module):
    def __init__(self):
        super(Wav2Lip_disc_qual, self).__init__()

        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)),  # 48,96

            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2),  # 48,48
                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2),  # 24,24
                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2),  # 12,12
                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 6,6
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),  # 3,3
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0),  # 1, 1
                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])

        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
        self.label_noise = .0

    def get_lower_half(self, face_sequences):
        return face_sequences[:, :, face_sequences.size(2)//2:]

    def to_2d(self, face_sequences):
        B = face_sequences.size(0)
        face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
        return face_sequences

    def perceptual_forward(self, false_face_sequences):
        false_face_sequences = self.to_2d(false_face_sequences)
        false_face_sequences = self.get_lower_half(false_face_sequences)

        false_feats = false_face_sequences
        for f in self.face_encoder_blocks:
            false_feats = f(false_feats)

        false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
                                                 torch.ones((len(false_feats), 1)).cuda())

        return false_pred_loss

    def forward(self, face_sequences):
        face_sequences = self.to_2d(face_sequences)
        face_sequences = self.get_lower_half(face_sequences)

        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)

        return self.binary_pred(x).view(len(x), -1)
predict.py
ADDED
@@ -0,0 +1,144 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import subprocess

from cog import BasePredictor, Input, Path

import inference

from time import time

from functools import wraps
import torch


def make_mem_efficient(cls: BasePredictor):
    if not torch.cuda.is_available():
        return cls

    old_setup = cls.setup
    old_predict = cls.predict

    @wraps(old_setup)
    def new_setup(self, *args, **kwargs):
        ret = old_setup(self, *args, **kwargs)
        _move_to(self, "cpu")
        return ret

    @wraps(old_predict)
    def new_predict(self, *args, **kwargs):
        _move_to(self, "cuda")
        try:
            ret = old_predict(self, *args, **kwargs)
        finally:
            _move_to(self, "cpu")
        return ret

    cls.setup = new_setup
    cls.predict = new_predict

    return cls


def _move_to(self, device):
    try:
        self = self.cached_models
    except AttributeError:
        pass
    for attr, value in vars(self).items():
        try:
            value = value.to(device)
        except AttributeError:
            pass
        else:
            print(f"Moving {self.__name__}.{attr} to {device}")
            setattr(self, attr, value)
    torch.cuda.empty_cache()


@make_mem_efficient
class Predictor(BasePredictor):
    cached_models = inference

    def setup(self):
        inference.do_load("checkpoints/wav2lip_gan.pth")

    def predict(
        self,
        face: Path = Input(description="video/image that contains faces to use"),
        audio: Path = Input(description="video/audio file to use as raw audio source"),
        pads: str = Input(
            description="Padding for the detected face bounding box.\n"
                        "Please adjust to include chin at least\n"
                        'Format: "top bottom left right"',
            default="0 10 0 0",
        ),
        smooth: bool = Input(
            description="Smooth face detections over a short temporal window",
            default=True,
        ),
        fps: float = Input(
            description="Can be specified only if input is a static image",
            default=25.0,
        ),
        out_height: int = Input(
            description="Output video height. Best results are obtained at 480 or 720",
            default=480,
        ),
    ) -> Path:
        try:
            os.remove("results/result_voice.mp4")
        except FileNotFoundError:
            pass

        face_ext = os.path.splitext(face)[-1]
        if face_ext not in [".mp4", ".mov", ".png", ".jpg", ".jpeg", ".gif", ".mkv", ".webp"]:
            raise ValueError(f'Unsupported face format {face_ext!r}')

        audio_ext = os.path.splitext(audio)[-1]
        if audio_ext not in [".wav", ".mp3"]:
            raise ValueError(f'Unsupported audio format {audio_ext!r}')

        args = [
            "--checkpoint_path", "checkpoints/wav2lip_gan.pth",
            "--face", str(face),
            "--audio", str(audio),
            "--pads", *pads.split(" "),
            "--fps", str(fps),
            "--out_height", str(out_height),
        ]
        if not smooth:
            args += ["--nosmooth"]

        print("-> run:", " ".join(args))
        inference.args = inference.parser.parse_args(args)

        s = time()

        try:
            inference.main()
        except ValueError as e:
            print('-> Encountered error, skipping lipsync:', e)

            args = [
                "ffmpeg", "-y",
                # "-vsync", "0", "-hwaccel", "cuda", "-hwaccel_output_format", "cuda",
                "-stream_loop", "-1",
                "-i", str(face),
                "-i", str(audio),
                "-shortest",
                "-fflags", "+shortest",
                "-max_interleave_delta", "100M",
                "-map", "0:v:0",
                "-map", "1:a:0",
                # "-c", "copy",
                # "-c:v", "h264_nvenc",
                "results/result_voice.mp4",
            ]
            print("-> run:", " ".join(args))
            print(subprocess.check_output(args, encoding="utf-8"))

        print(time() - s)

        return Path("results/result_voice.mp4")
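A usage sketch for the predictor above, assuming the Cog CLI is installed and the checkpoints have been downloaded via scripts/download_models.sh (input file names are placeholders):

cog predict -i face=@face.mp4 -i audio=@speech.wav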
preprocess.py
ADDED
@@ -0,0 +1,113 @@
import sys

if sys.version_info[0] < 3 and sys.version_info[1] < 2:
    raise Exception("Must be using >= Python 3.2")

from os import listdir, path

if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
    raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth \
                            before running this script!')

import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import argparse, os, cv2, traceback, subprocess
from tqdm import tqdm
from glob import glob
import audio
from hparams import hparams as hp

import face_detection

parser = argparse.ArgumentParser()

parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int)
parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)

args = parser.parse_args()

fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
                                   device='cuda:{}'.format(id)) for id in range(args.ngpu)]

template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'

def process_video_file(vfile, args, gpu_id):
    video_stream = cv2.VideoCapture(vfile)

    frames = []
    while 1:
        still_reading, frame = video_stream.read()
        if not still_reading:
            video_stream.release()
            break
        frames.append(frame)

    vidname = os.path.basename(vfile).split('.')[0]
    dirname = vfile.split('/')[-2]

    fulldir = path.join(args.preprocessed_root, dirname, vidname)
    os.makedirs(fulldir, exist_ok=True)

    batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]

    i = -1
    for fb in batches:
        preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))

        for j, f in enumerate(preds):
            i += 1
            if f is None:
                continue

            x1, y1, x2, y2 = f
            cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])

def process_audio_file(vfile, args):
    vidname = os.path.basename(vfile).split('.')[0]
    dirname = vfile.split('/')[-2]

    fulldir = path.join(args.preprocessed_root, dirname, vidname)
    os.makedirs(fulldir, exist_ok=True)

    wavpath = path.join(fulldir, 'audio.wav')

    command = template.format(vfile, wavpath)
    subprocess.call(command, shell=True)


def mp_handler(job):
    vfile, args, gpu_id = job
    try:
        process_video_file(vfile, args, gpu_id)
    except KeyboardInterrupt:
        exit(0)
    except:
        traceback.print_exc()

def main(args):
    print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu))

    filelist = glob(path.join(args.data_root, '*/*.mp4'))

    jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)]
    p = ThreadPoolExecutor(args.ngpu)
    futures = [p.submit(mp_handler, j) for j in jobs]
    _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]

    print('Dumping audios...')

    for vfile in tqdm(filelist):
        try:
            process_audio_file(vfile, args)
        except KeyboardInterrupt:
            exit(0)
        except:
            traceback.print_exc()
            continue

if __name__ == '__main__':
    main(args)
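A usage sketch for the preprocessing script above, following the LRS2-style layout described in its --data_root help text (paths are placeholders):

python preprocess.py \
    --data_root /path/to/LRS2/main \
    --preprocessed_root lrs2_preprocessed/ \
    --ngpu 1 \
    --batch_size 32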
requirements.txt
ADDED
@@ -0,0 +1,10 @@
numpy==1.23.4
librosa==0.7.0
opencv-python==4.6.0.66
--extra-index-url=https://download.pytorch.org/whl/cu116
torch==1.12.1+cu116
--extra-index-url=https://download.pytorch.org/whl/cu116
torchvision==0.13.1+cu116
tqdm==4.45.0
numba==0.48
mediapipe==0.8.11
requirementsCPU.txt
ADDED
@@ -0,0 +1,9 @@
librosa
numpy
opencv-contrib-python
opencv-python
-f https://download.pytorch.org/whl/torch_stable.html
torch
torchvision
tqdm
numba
requirements_colab.txt
ADDED
@@ -0,0 +1,7 @@
numpy==1.23.4
librosa
opencv-python
torch
torchvision
tqdm
numba
results/README.md
ADDED
@@ -0,0 +1 @@
Generated results will be placed in this folder by default.
scripts/download_models.sh
ADDED
@@ -0,0 +1,7 @@
#!/usr/bin/env bash

set -ex

wget -c -O checkpoints/wav2lip_gan.pth 'https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?share=EdjI7bZlgApMqsVoEUUXpLsBxqXbn5z8VTmoxp55YNDcIA'
wget -c -O checkpoints/mobilenet.pth 'https://github.com/elliottzheng/face-detection/releases/download/0.0.1/mobilenet0.25_Final.pth'
wget -c -O checkpoints/resnet50.pth 'https://github.com/elliottzheng/face-detection/releases/download/0.0.1/Resnet50_Final.pth'
scripts/run-dev.sh
ADDED
@@ -0,0 +1,13 @@
#!/usr/bin/env bash

NAME=wav2lip-dev

set -ex

docker build . -t $NAME
docker run -it --rm \
    --name $NAME \
    -v $PWD/checkpoints:/src/checkpoints \
    -p 6001:5000 \
    --gpus all \
    $NAME