Spaces:
Running
Running
Josh Brown Kramer
commited on
Commit
·
a56695f
1
Parent(s):
d6c0ea5
Don't include gradio
Browse files- .gitignore +3 -0
- app.py +31 -12
- faceparsing.py +54 -0
- requirements.txt +3 -1
.gitignore
CHANGED
@@ -176,4 +176,7 @@ pyrightconfig.json
|
|
176 |
# VSCode
|
177 |
.vscode/
|
178 |
|
|
|
|
|
|
|
179 |
# End of https://www.toptal.com/developers/gitignore/api/python
|
|
|
176 |
# VSCode
|
177 |
.vscode/
|
178 |
|
179 |
+
# Gradio
|
180 |
+
.gradio/
|
181 |
+
|
182 |
# End of https://www.toptal.com/developers/gitignore/api/python
|
app.py
CHANGED
@@ -2,6 +2,10 @@ import gradio as gr
|
|
2 |
import zombie
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
import onnxruntime as ort
|
|
|
|
|
|
|
|
|
5 |
# import torch
|
6 |
# from your_pix2pixhd_code import YourPix2PixHDModel, load_image, tensor2im # Adapt these imports
|
7 |
|
@@ -30,27 +34,42 @@ ort_session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider
|
|
30 |
|
31 |
# # return output_image
|
32 |
|
33 |
-
def predict(input_image):
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
# --- 3. Create the Gradio Interface ---
|
42 |
-
title = "
|
43 |
description = "Upload an image to see the pix2pixHD model in action."
|
44 |
-
article = "<p style='text-align: center'>Model based on the <a href='https://github.com/NVIDIA/pix2pixHD' target='_blank'>pix2pixHD repository</a
|
|
|
45 |
|
46 |
demo = gr.Interface(
|
47 |
fn=predict,
|
48 |
-
inputs=
|
|
|
|
|
|
|
49 |
outputs=gr.Image(type="pil", label="Output Image"),
|
50 |
title=title,
|
51 |
description=description,
|
52 |
article=article,
|
53 |
)
|
54 |
|
55 |
-
demo.launch()
|
56 |
-
|
|
|
2 |
import zombie
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
import onnxruntime as ort
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from faceparsing import get_face_mask
|
8 |
+
|
9 |
# import torch
|
10 |
# from your_pix2pixhd_code import YourPix2PixHDModel, load_image, tensor2im # Adapt these imports
|
11 |
|
|
|
34 |
|
35 |
# # return output_image
|
36 |
|
37 |
+
def predict(input_image, mode):
|
38 |
+
if mode == "Classic":
|
39 |
+
# Use the transition_onnx function for side-by-side comparison
|
40 |
+
zombie_image = zombie.transition_onnx(input_image, ort_session)
|
41 |
+
if zombie_image is None:
|
42 |
+
return "No face found"
|
43 |
+
return zombie_image
|
44 |
+
elif mode == "In Place":
|
45 |
+
# Use the make_faces_zombie_from_array function for in-place transformation
|
46 |
+
#zombie_image = zombie.make_faces_zombie_from_array(im_array, None, ort_session)
|
47 |
+
#if zombie_image is None:
|
48 |
+
# return "No face found"
|
49 |
+
#return zombie_image
|
50 |
+
face_mask = get_face_mask(input_image)
|
51 |
+
return face_mask
|
52 |
+
|
53 |
+
else:
|
54 |
+
return "Invalid mode selected"
|
55 |
|
56 |
# --- 3. Create the Gradio Interface ---
|
57 |
+
title = "Make Me A Zombie"
|
58 |
description = "Upload an image to see the pix2pixHD model in action."
|
59 |
+
article = """<p style='text-align: center'>Model based on the <a href='https://github.com/NVIDIA/pix2pixHD' target='_blank'>pix2pixHD repository</a>.
|
60 |
+
More details at <a href='https://makemeazombie.com' target='_blank'>makemeazombie.com</a>.</p>"""
|
61 |
|
62 |
demo = gr.Interface(
|
63 |
fn=predict,
|
64 |
+
inputs=[
|
65 |
+
gr.Image(type="pil", label="Input Image"),
|
66 |
+
gr.Dropdown(choices=["Classic", "In Place"], value="Classic", label="Mode")
|
67 |
+
],
|
68 |
outputs=gr.Image(type="pil", label="Output Image"),
|
69 |
title=title,
|
70 |
description=description,
|
71 |
article=article,
|
72 |
)
|
73 |
|
74 |
+
#demo.launch()
|
75 |
+
demo.launch(debug=True)
|
faceparsing.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
# convenience expression for automatically determining device
|
9 |
+
device = (
|
10 |
+
"cuda"
|
11 |
+
# Device for NVIDIA or AMD GPUs
|
12 |
+
if torch.cuda.is_available()
|
13 |
+
else "mps"
|
14 |
+
# Device for Apple Silicon (Metal Performance Shaders)
|
15 |
+
if torch.backends.mps.is_available()
|
16 |
+
else "cpu"
|
17 |
+
)
|
18 |
+
|
19 |
+
# load models
|
20 |
+
image_processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
|
21 |
+
model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
|
22 |
+
model.to(device)
|
23 |
+
|
24 |
+
def get_face_mask(image):
|
25 |
+
# run inference on image
|
26 |
+
inputs = image_processor(images=image, return_tensors="pt").to(device)
|
27 |
+
outputs = model(**inputs)
|
28 |
+
logits = outputs.logits # shape (batch_size, num_labels, ~height/4, ~width/4)
|
29 |
+
|
30 |
+
# resize output to match input image dimensions
|
31 |
+
upsampled_logits = nn.functional.interpolate(logits,
|
32 |
+
size=image.size[::-1], # H x W
|
33 |
+
mode='bilinear',
|
34 |
+
align_corners=False)
|
35 |
+
|
36 |
+
# get label masks
|
37 |
+
labels = upsampled_logits.argmax(dim=1)[0]
|
38 |
+
|
39 |
+
# move to CPU to visualize in matplotlib
|
40 |
+
labels_viz = labels.cpu().numpy()
|
41 |
+
|
42 |
+
#Map to something more colorful. Use a color map to map the labels to a color.
|
43 |
+
#Create a color map for colors 0 through 18
|
44 |
+
color_map = plt.get_cmap('tab20')
|
45 |
+
#Map the labels to a color
|
46 |
+
colors = color_map(labels_viz)
|
47 |
+
|
48 |
+
#Convert to PIL Image
|
49 |
+
colors_pil = Image.fromarray((colors * 255).astype(np.uint8))
|
50 |
+
|
51 |
+
|
52 |
+
return labels_viz
|
53 |
+
|
54 |
+
|
requirements.txt
CHANGED
@@ -2,4 +2,6 @@ gradio
|
|
2 |
onnxruntime-gpu
|
3 |
opencv-python
|
4 |
numpy
|
5 |
-
mediapipe
|
|
|
|
|
|
2 |
onnxruntime-gpu
|
3 |
opencv-python
|
4 |
numpy
|
5 |
+
mediapipe
|
6 |
+
torch
|
7 |
+
transformers
|