teja141290 commited on
Commit
5bd2d37
·
1 Parent(s): 412f263

Move app.py to repo root for Hugging Face Spaces

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from model.models import UNet
4
+ from scripts.test_functions import process_image, process_video
5
+
6
+ window_size = 512
7
+ stride = 256
8
+ steps = 18
9
+ frame_count = 0
10
+
11
+ def get_model():
12
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
+ unet_model = UNet().to(device)
14
+ unet_model.load_state_dict(torch.load("model/best_unet_model.pth", map_location=device))
15
+ unet_model.eval()
16
+ return unet_model
17
+
18
+ unet_model = get_model()
19
+
20
+ def block_img(image, source_age, target_age):
21
+ from PIL import Image as PILImage
22
+ import numpy as np
23
+ if isinstance(image, str):
24
+ image = PILImage.open(image).convert('RGB')
25
+ elif isinstance(image, np.ndarray) and image.dtype == object:
26
+ image = image.astype(np.uint8)
27
+ return process_image(unet_model, image, video=False, source_age=source_age,
28
+ target_age=target_age, window_size=window_size, stride=stride)
29
+
30
+ def block_img_vid(image, source_age):
31
+ from PIL import Image as PILImage
32
+ import numpy as np
33
+ if isinstance(image, str):
34
+ image = PILImage.open(image).convert('RGB')
35
+ elif isinstance(image, np.ndarray) and image.dtype == object:
36
+ image = image.astype(np.uint8)
37
+ return process_image(unet_model, image, video=True, source_age=source_age,
38
+ target_age=0, window_size=window_size, stride=stride, steps=steps)
39
+
40
+ def block_vid(video_path, source_age, target_age):
41
+ return process_video(unet_model, video_path, source_age, target_age,
42
+ window_size=window_size, stride=stride, frame_count=frame_count)
43
+
44
+ demo_img = gr.Interface(
45
+ fn=block_img,
46
+ inputs=[
47
+ gr.Image(type="pil"),
48
+ gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
49
+ gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
50
+ ],
51
+ outputs="image",
52
+ examples=[
53
+ ['assets/gradio_example_images/1.png', 20, 80],
54
+ ['assets/gradio_example_images/2.png', 75, 40],
55
+ ['assets/gradio_example_images/3.png', 30, 70],
56
+ ['assets/gradio_example_images/4.png', 22, 60],
57
+ ['assets/gradio_example_images/5.png', 28, 75],
58
+ ['assets/gradio_example_images/6.png', 35, 15]
59
+ ],
60
+ description="Input an image of a person and age them from the source age to the target age."
61
+ )
62
+
63
+ demo_img_vid = gr.Interface(
64
+ fn=block_img_vid,
65
+ inputs=[
66
+ gr.Image(type="pil"),
67
+ gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
68
+ ],
69
+ outputs=gr.Video(),
70
+ examples=[
71
+ ['assets/gradio_example_images/1.png', 20],
72
+ ['assets/gradio_example_images/2.png', 75],
73
+ ['assets/gradio_example_images/3.png', 30],
74
+ ['assets/gradio_example_images/4.png', 22],
75
+ ['assets/gradio_example_images/5.png', 28],
76
+ ['assets/gradio_example_images/6.png', 35]
77
+ ],
78
+ description="Input an image of a person and a video will be returned of the person at different ages."
79
+ )
80
+
81
+ demo_vid = gr.Interface(
82
+ fn=block_vid,
83
+ inputs=[
84
+ gr.Video(),
85
+ gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
86
+ gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
87
+ ],
88
+ outputs=gr.Video(),
89
+ examples=[
90
+ ['assets/gradio_example_images/orig.mp4', 35, 60],
91
+ ],
92
+ description="Input a video of a person, and it will be aged frame-by-frame."
93
+ )
94
+
95
+ demo = gr.TabbedInterface([demo_img, demo_img_vid, demo_vid],
96
+ tab_names=['Image inference demo', 'Image animation demo', 'Video inference demo'],
97
+ title="Face Re-Aging Demo",
98
+ )
99
+
100
+ if __name__ == "__main__":
101
+ demo.launch()