hysts HF staff commited on
Commit
8dbbb46
β€’
1 Parent(s): 6614b1a
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +32 -43
  3. requirements.txt +76 -25
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ‘
4
  colorFrom: indigo
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.0.19
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: indigo
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.48.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -1,15 +1,18 @@
1
- import gradio as gr
2
  import os
3
- import shutil
 
 
4
  import torch
5
  from PIL import Image
6
- import argparse
7
- import pathlib
8
 
9
- os.system("git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model")
10
- os.chdir("Thin-Plate-Spline-Motion-Model")
11
- os.system("mkdir checkpoints")
12
- os.system("gdown 1-CKOjv_y_TzNe-dwQsjjeVxJUuyBAb5X -O checkpoints/vox.pth.tar")
 
 
 
 
13
 
14
 
15
 
@@ -40,24 +43,21 @@ def update_style_image(style_name: str) -> dict:
40
  return gr.Markdown.update(value=text)
41
 
42
 
43
- def set_example_image(example: list) -> dict:
44
- return gr.Image.update(value=example[0])
 
45
 
46
- def set_example_video(example: list) -> dict:
47
- return gr.Video.update(value=example[0])
 
 
 
 
48
 
49
- def inference(img,vid):
50
- if not os.path.exists('temp'):
51
- os.system('mkdir temp')
52
-
53
- img.save("temp/image.jpg", "JPEG")
54
- os.system(f"python demo.py --config config/vox-256.yaml --checkpoint ./checkpoints/vox.pth.tar --source_image 'temp/image.jpg' --driving_video {vid} --result_video './temp/result.mp4' --cpu")
55
- return './temp/result.mp4'
56
-
57
 
58
 
59
  def main():
60
- with gr.Blocks(theme="huggingface", css='style.css') as demo:
61
  gr.Markdown(title)
62
  gr.Markdown(DESCRIPTION)
63
 
@@ -71,12 +71,11 @@ def main():
71
  with gr.Row():
72
  input_image = gr.Image(label='Input Image',
73
  type="pil")
74
-
75
  with gr.Row():
76
  paths = sorted(pathlib.Path('assets').glob('*.png'))
77
- example_images = gr.Dataset(components=[input_image],
78
- samples=[[path.as_posix()]
79
- for path in paths])
80
 
81
  with gr.Box():
82
  gr.Markdown('''## Step 2 (Select Driving Video)
@@ -86,17 +85,16 @@ def main():
86
  with gr.Column():
87
  with gr.Row():
88
  driving_video = gr.Video(label='Driving Video',
89
- format="mp4")
90
 
91
  with gr.Row():
92
  paths = sorted(pathlib.Path('assets').glob('*.mp4'))
93
- example_video = gr.Dataset(components=[driving_video],
94
- samples=[[path.as_posix()]
95
- for path in paths])
96
 
97
  with gr.Box():
98
  gr.Markdown('''## Step 3 (Generate Animated Image based on the Video)
99
- - Hit the **Generate** button. (Note: As it runs on the CPU, it takes ~ 3 minutes to generate final results.)
100
  ''')
101
  with gr.Row():
102
  with gr.Column():
@@ -104,7 +102,7 @@ def main():
104
  generate_button = gr.Button('Generate')
105
 
106
  with gr.Column():
107
- result = gr.Video(type="file", label="Output")
108
  gr.Markdown(FOOTER)
109
  generate_button.click(fn=inference,
110
  inputs=[
@@ -112,17 +110,8 @@ def main():
112
  driving_video
113
  ],
114
  outputs=result)
115
- example_images.click(fn=set_example_image,
116
- inputs=example_images,
117
- outputs=example_images.components)
118
- example_video.click(fn=set_example_video,
119
- inputs=example_video,
120
- outputs=example_video.components)
121
-
122
- demo.launch(
123
- enable_queue=True,
124
- debug=True
125
- )
126
 
127
  if __name__ == '__main__':
128
- main()
 
 
1
  import os
2
+ import pathlib
3
+
4
+ import gradio as gr
5
  import torch
6
  from PIL import Image
 
 
7
 
8
+ repo_dir = pathlib.Path("Thin-Plate-Spline-Motion-Model").absolute()
9
+ if not repo_dir.exists():
10
+ os.system("git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model")
11
+ os.chdir(repo_dir.name)
12
+ if not (repo_dir / "checkpoints").exists():
13
+ os.system("mkdir checkpoints")
14
+ if not (repo_dir / "checkpoints/vox.pth.tar").exists():
15
+ os.system("gdown 1-CKOjv_y_TzNe-dwQsjjeVxJUuyBAb5X -O checkpoints/vox.pth.tar")
16
 
17
 
18
 
 
43
  return gr.Markdown.update(value=text)
44
 
45
 
46
+ def inference(img, vid):
47
+ if not os.path.exists('temp'):
48
+ os.system('mkdir temp')
49
 
50
+ img.save("temp/image.jpg", "JPEG")
51
+ if torch.cuda.is_available():
52
+ os.system(f"python demo.py --config config/vox-256.yaml --checkpoint ./checkpoints/vox.pth.tar --source_image 'temp/image.jpg' --driving_video {vid} --result_video './temp/result.mp4'")
53
+ else:
54
+ os.system(f"python demo.py --config config/vox-256.yaml --checkpoint ./checkpoints/vox.pth.tar --source_image 'temp/image.jpg' --driving_video {vid} --result_video './temp/result.mp4' --cpu")
55
+ return './temp/result.mp4'
56
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  def main():
60
+ with gr.Blocks(css='style.css') as demo:
61
  gr.Markdown(title)
62
  gr.Markdown(DESCRIPTION)
63
 
 
71
  with gr.Row():
72
  input_image = gr.Image(label='Input Image',
73
  type="pil")
74
+
75
  with gr.Row():
76
  paths = sorted(pathlib.Path('assets').glob('*.png'))
77
+ gr.Examples(inputs=[input_image],
78
+ examples=[[path.as_posix()] for path in paths])
 
79
 
80
  with gr.Box():
81
  gr.Markdown('''## Step 2 (Select Driving Video)
 
85
  with gr.Column():
86
  with gr.Row():
87
  driving_video = gr.Video(label='Driving Video',
88
+ format="mp4")
89
 
90
  with gr.Row():
91
  paths = sorted(pathlib.Path('assets').glob('*.mp4'))
92
+ gr.Examples(inputs=[driving_video],
93
+ examples=[[path.as_posix()] for path in paths])
 
94
 
95
  with gr.Box():
96
  gr.Markdown('''## Step 3 (Generate Animated Image based on the Video)
97
+ - Hit the **Generate** button. (Note: On cpu-basic, it takes ~ 10 minutes to generate final results.)
98
  ''')
99
  with gr.Row():
100
  with gr.Column():
 
102
  generate_button = gr.Button('Generate')
103
 
104
  with gr.Column():
105
+ result = gr.Video(label="Output")
106
  gr.Markdown(FOOTER)
107
  generate_button.click(fn=inference,
108
  inputs=[
 
110
  driving_video
111
  ],
112
  outputs=result)
113
+
114
+ demo.queue(max_size=10).launch()
 
 
 
 
 
 
 
 
 
115
 
116
  if __name__ == '__main__':
117
+ main()
requirements.txt CHANGED
@@ -1,27 +1,78 @@
1
- cffi==1.14.6
2
- cycler==0.10.0
3
- decorator==5.1.0
4
- face-alignment==1.3.5
5
- imageio==2.9.0
6
- imageio-ffmpeg==0.4.5
7
- kiwisolver==1.3.2
8
- matplotlib==3.4.3
9
- networkx==2.6.3
10
- numpy==1.20.3
11
- pandas==1.3.3
12
- Pillow
13
- pycparser==2.20
14
- pyparsing==2.4.7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  python-dateutil==2.8.2
16
- pytz==2021.1
17
- PyWavelets==1.1.1
18
- PyYAML
19
- scikit-image
20
- scikit-learn
21
- scipy
 
 
 
 
22
  six==1.16.0
23
- torch==1.11.0
24
- torchvision==0.12.0
25
- tqdm==4.62.3
26
- gradio
27
- gdown
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.1.2
3
+ annotated-types==0.6.0
4
+ anyio==3.7.1
5
+ attrs==23.1.0
6
+ beautifulsoup4==4.12.2
7
+ certifi==2023.7.22
8
+ charset-normalizer==3.3.0
9
+ click==8.1.7
10
+ cmake==3.27.7
11
+ contourpy==1.1.1
12
+ cycler==0.12.1
13
+ exceptiongroup==1.1.3
14
+ fastapi==0.103.2
15
+ ffmpy==0.3.1
16
+ filelock==3.12.4
17
+ fonttools==4.43.1
18
+ fsspec==2023.9.2
19
+ gdown==4.7.1
20
+ gradio==3.48.0
21
+ gradio_client==0.6.1
22
+ h11==0.14.0
23
+ httpcore==0.18.0
24
+ httpx==0.25.0
25
+ huggingface-hub==0.18.0
26
+ idna==3.4
27
+ imageio==2.31.5
28
+ imageio-ffmpeg==0.4.9
29
+ importlib-resources==6.1.0
30
+ Jinja2==3.1.2
31
+ jsonschema==4.19.1
32
+ jsonschema-specifications==2023.7.1
33
+ kiwisolver==1.4.5
34
+ lazy_loader==0.3
35
+ lit==17.0.3
36
+ MarkupSafe==2.1.3
37
+ matplotlib==3.8.0
38
+ mpmath==1.3.0
39
+ networkx==3.1
40
+ numpy==1.25.2
41
+ orjson==3.9.9
42
+ packaging==23.2
43
+ pandas==2.1.1
44
+ Pillow==10.1.0
45
+ psutil==5.9.6
46
+ pydantic==2.4.2
47
+ pydantic_core==2.10.1
48
+ pydub==0.25.1
49
+ pyparsing==3.1.1
50
+ PySocks==1.7.1
51
  python-dateutil==2.8.2
52
+ python-multipart==0.0.6
53
+ pytz==2023.3.post1
54
+ PyYAML==6.0.1
55
+ referencing==0.30.2
56
+ requests==2.31.0
57
+ rpds-py==0.10.6
58
+ scikit-image==0.22.0
59
+ scipy==1.11.3
60
+ semantic-version==2.10.0
61
+ setuptools-scm==8.0.4
62
  six==1.16.0
63
+ sniffio==1.3.0
64
+ soupsieve==2.5
65
+ starlette==0.27.0
66
+ sympy==1.12
67
+ tifffile==2023.9.26
68
+ tomli==2.0.1
69
+ toolz==0.12.0
70
+ torch==2.0.0
71
+ torchvision==0.15.1
72
+ tqdm==4.66.1
73
+ triton==2.0.0
74
+ typing_extensions==4.8.0
75
+ tzdata==2023.3
76
+ urllib3==2.0.7
77
+ uvicorn==0.23.2
78
+ websockets==11.0.3