Andrei Boiarov committed
Commit: 581506e · Parent: aa94b73
Update app file
app.py CHANGED
@@ -1,117 +1,115 @@
-
-
-
-
-#
-# import gradio as gr
-#
-# feature_extractor = ViTFeatureExtractor.from_pretrained('andrewbo29/vit-mae-base-formula1')
-# model = ViTMAEForPreTraining.from_pretrained('andrewbo29/vit-mae-base-formula1')
-#
-# imagenet_mean = np.array(feature_extractor.image_mean)
-# imagenet_std = np.array(feature_extractor.image_std)
-#
-#
-# def prep_image(image):
-#     return torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int().cpu().numpy()
-#
-#
-# def reconstruct(img):
-#     image = Image.fromarray(img)
-#     pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
-#
-#     outputs = model(pixel_values)
-#     y = model.unpatchify(outputs.logits)
-#     y = torch.einsum('nchw->nhwc', y).detach().cpu()
-#
-#     # visualize the mask
-#     mask = outputs.mask.detach()
-#     mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3)  # (N, H*W, p*p*3)
-#     mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
-#     mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
-#
-#     x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
-#
-#     # masked image
-#     im_masked = x * (1 - mask)
-#
-#     # MAE reconstruction pasted with visible patches
-#     im_paste = x * (1 - mask) + y * mask
-#
-#     out_masked = prep_image(im_masked[0])
-#     out_rec = prep_image(y[0])
-#     out_rec_vis = prep_image(im_paste[0])
-#
-#     # out_masked, out_rec, out_rec_vis = img, img, img
-#
-#     return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
-#     # return [(img, '1')]
-#
-#
-# with gr.Blocks() as demo:
-#     with gr.Column(variant="panel"):
-#         with gr.Row():
-#             img = gr.Image(
-#                 label="Enter your prompt",
-#                 container=False,
-#             )
-#             btn = gr.Button("Generate image", scale=0)
-#
-#     # gallery = gr.Gallery(
-#     #     label="Generated images", show_label=False, elem_id="gallery"
-#     #     , columns=[3], rows=[1], height='auto', container=True)
-#
-#     gallery = gr.Gallery(columns=3,
-#                          rows=1,
-#                          height='800px',
-#                          object_fit='none')
-#
-#     btn.click(reconstruct, img, gallery)
-#
-# if __name__ == "__main__":
-#     demo.launch()
-
-# This demo needs to be run from the repo folder.
-# python demo/fake_gan/run.py
-import random
+from transformers import ViTFeatureExtractor, ViTMAEForPreTraining
+import numpy as np
+import torch
+from PIL import Image
 
 import gradio as gr
 
+feature_extractor = ViTFeatureExtractor.from_pretrained('andrewbo29/vit-mae-base-formula1')
+model = ViTMAEForPreTraining.from_pretrained('andrewbo29/vit-mae-base-formula1')
+
+imagenet_mean = np.array(feature_extractor.image_mean)
+imagenet_std = np.array(feature_extractor.image_std)
+
+
+def prep_image(image):
+    return torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int().cpu().numpy()
 
-def fake_gan():
-    images = [
-        (random.choice(
-            [
-                "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
-                "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
-                "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
-                "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
-                "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
-            ]
-        ), f"label {i}" if i != 0 else "label" * 50)
-        for i in range(3)
-    ]
-    return images
+
+def reconstruct(img):
+    # image = Image.fromarray(img)
+    # pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
+    #
+    # outputs = model(pixel_values)
+    # y = model.unpatchify(outputs.logits)
+    # y = torch.einsum('nchw->nhwc', y).detach().cpu()
+    #
+    # # visualize the mask
+    # mask = outputs.mask.detach()
+    # mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3)  # (N, H*W, p*p*3)
+    # mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
+    # mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
+    #
+    # x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
+    #
+    # # masked image
+    # im_masked = x * (1 - mask)
+    #
+    # # MAE reconstruction pasted with visible patches
+    # im_paste = x * (1 - mask) + y * mask
+    #
+    # out_masked = prep_image(im_masked[0])
+    # out_rec = prep_image(y[0])
+    # out_rec_vis = prep_image(im_paste[0])
+    #
+    # # out_masked, out_rec, out_rec_vis = img, img, img
+    #
+    # return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
+    return [(img, 'label 1')]
 
 
 with gr.Blocks() as demo:
     with gr.Column(variant="panel"):
         with gr.Row():
-            text = gr.Textbox(
+            img = gr.Image(
                 label="Enter your prompt",
-                max_lines=1,
-                placeholder="Enter your prompt",
                 container=False,
             )
             btn = gr.Button("Generate image", scale=0)
 
-    gallery = gr.Gallery(
-        label="Generated images", show_label=False, elem_id="gallery"
-        , columns=[2], rows=[2], object_fit="contain", height="auto")
+    # gallery = gr.Gallery(
+    #     label="Generated images", show_label=False, elem_id="gallery"
+    #     , columns=[3], rows=[1], height='auto', container=True)
+
+    gallery = gr.Gallery(columns=1,
+                         rows=1,
+                         height='800px',
+                         object_fit='none')
 
-    btn.click(fake_gan, None, gallery)
+    btn.click(reconstruct, img, gallery)
 
 if __name__ == "__main__":
     demo.launch()
 
+# import random
+#
+# import gradio as gr
+#
+#
+# def fake_gan():
+#     images = [
+#         (random.choice(
+#             [
+#                 "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
+#                 "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
+#                 "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
+#                 "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
+#                 "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
+#             ]
+#         ), f"label {i}" if i != 0 else "label" * 50)
+#         for i in range(3)
+#     ]
+#     return images
+#
+#
+# with gr.Blocks() as demo:
+#     with gr.Column(variant="panel"):
+#         with gr.Row():
+#             text = gr.Textbox(
+#                 label="Enter your prompt",
+#                 max_lines=1,
+#                 placeholder="Enter your prompt",
+#                 container=False,
+#             )
+#             btn = gr.Button("Generate image", scale=0)
+#
+#     gallery = gr.Gallery(
+#         label="Generated images", show_label=False, elem_id="gallery"
+#         , columns=[2], rows=[2], object_fit="contain", height="auto")
+#
+#     btn.click(fake_gan, None, gallery)
+#
+# if __name__ == "__main__":
+#     demo.launch()
+
 
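
For context: in the new app.py the MAE forward pass inside reconstruct() stays commented out and the function only returns the input image with a placeholder label, even though btn.click(reconstruct, img, gallery) is already wired to the gallery. Below is a minimal sketch of what re-enabling that commented block would look like; it only rearranges code that already appears in this diff (same 'andrewbo29/vit-mae-base-formula1' checkpoint, same preprocessing), has not been run against this Space, and is illustrative rather than the committed behaviour.

import numpy as np
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTMAEForPreTraining

feature_extractor = ViTFeatureExtractor.from_pretrained('andrewbo29/vit-mae-base-formula1')
model = ViTMAEForPreTraining.from_pretrained('andrewbo29/vit-mae-base-formula1')

imagenet_mean = np.array(feature_extractor.image_mean)
imagenet_std = np.array(feature_extractor.image_std)


def prep_image(image):
    # Undo the ImageNet normalisation and return an HxWx3 integer array for display.
    return torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int().cpu().numpy()


def reconstruct(img):
    # Preprocess the uploaded image and run the MAE encoder/decoder.
    pixel_values = feature_extractor(Image.fromarray(img), return_tensors='pt').pixel_values
    outputs = model(pixel_values)

    # Decoder output back in pixel space, channels-last.
    y = torch.einsum('nchw->nhwc', model.unpatchify(outputs.logits)).detach().cpu()

    # Expand the per-patch mask to pixel space (1 = masked, 0 = kept).
    mask = outputs.mask.detach().unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3)
    mask = torch.einsum('nchw->nhwc', model.unpatchify(mask)).detach().cpu()

    x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
    im_masked = x * (1 - mask)            # input with the masked patches blanked out
    im_paste = x * (1 - mask) + y * mask  # reconstruction pasted only into the masked patches

    return [(prep_image(im_masked[0]), 'masked'),
            (prep_image(y[0]), 'reconstruction'),
            (prep_image(im_paste[0]), 'reconstruction + visible')]

With this body in place, the existing btn.click wiring would push three labelled images into the gallery; since the new file creates the gallery with columns=1, they would stack vertically rather than sit side by side as in the commented-out columns=3 variant.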