Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Upload kornia_aug.py
Browse files- kornia_aug.py +142 -0
kornia_aug.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import kornia
|
3 |
+
from torch import nn
|
4 |
+
import torch
|
5 |
+
from torchvision.transforms import functional as F
|
6 |
+
from torchvision.utils import make_grid
|
7 |
+
from streamlit_ace import st_ace
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
IS_LOCAL = False #Change this
|
11 |
+
|
12 |
+
@st.cache(suppress_st_warning=True)
|
13 |
+
def set_transform(content):
|
14 |
+
# st.write("set transform")
|
15 |
+
try:
|
16 |
+
transform = eval(content, {"kornia": kornia, "nn": nn}, None)
|
17 |
+
except Exception as e:
|
18 |
+
st.write(f"There was an error: {e}")
|
19 |
+
transform = nn.Sequential()
|
20 |
+
return transform
|
21 |
+
|
22 |
+
st.markdown("# Kornia Augmentations Demo")
|
23 |
+
st.sidebar.markdown(
|
24 |
+
"[Kornia](https://github.com/kornia/kornia) is a *differentiable* computer vision library for PyTorch."
|
25 |
+
)
|
26 |
+
uploaded_file = st.sidebar.file_uploader("Choose a file")
|
27 |
+
if uploaded_file is not None:
|
28 |
+
im = Image.open(uploaded_file)
|
29 |
+
else:
|
30 |
+
im = Image.open("./images/pretty_bird.jpg")
|
31 |
+
scaler = int(im.height / 2)
|
32 |
+
st.sidebar.image(im, caption="Input Image", width=256)
|
33 |
+
image = F.pil_to_tensor(im).float() / 255
|
34 |
+
|
35 |
+
|
36 |
+
# batch size is just for show
|
37 |
+
batch_size = st.sidebar.slider("batch_size", min_value=4, max_value=16,value=8)
|
38 |
+
gpu = st.sidebar.checkbox("Use GPU!", value=True)
|
39 |
+
if not gpu:
|
40 |
+
st.sidebar.markdown("With Kornia you do ops on the GPU!")
|
41 |
+
device = torch.device("cpu")
|
42 |
+
else:
|
43 |
+
if not IS_LOCAL:
|
44 |
+
st.sidebar.markdown("(GPU Not available on hosted demo, try on your local!)")
|
45 |
+
# Credits
|
46 |
+
st.sidebar.caption("Demo made by [Ceyda Cinarel](https://linktr.ee/ceydai)")
|
47 |
+
st.sidebar.markdown("Clone [Code](https://github.com/cceyda/kornia-demo)")
|
48 |
+
device = torch.device("cpu")
|
49 |
+
else:
|
50 |
+
st.sidebar.markdown("Running on GPU~")
|
51 |
+
device = torch.device("cuda:0")
|
52 |
+
|
53 |
+
predefined_transforms = [
|
54 |
+
"""
|
55 |
+
nn.Sequential(
|
56 |
+
kornia.augmentation.RandomAffine(degrees=360,p=0.5),
|
57 |
+
kornia.augmentation.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.2, hue=0.3, p=1)
|
58 |
+
)
|
59 |
+
# p=0.5 is the probability of applying the transformation
|
60 |
+
""",
|
61 |
+
"""
|
62 |
+
nn.Sequential(
|
63 |
+
kornia.augmentation.RandomErasing(scale=(.4, .8), ratio=(.3, 1/.3), p=0.5),
|
64 |
+
)
|
65 |
+
""",
|
66 |
+
"""
|
67 |
+
nn.Sequential(
|
68 |
+
kornia.augmentation.RandomErasing(scale=(.4, .8), ratio=(.3, 1/.3), p=1, same_on_batch=True),
|
69 |
+
)
|
70 |
+
#By setting same_on_batch=True you can apply the same transform across the batch
|
71 |
+
""",
|
72 |
+
f"""
|
73 |
+
nn.Sequential(
|
74 |
+
kornia.augmentation.RandomResizedCrop(size=({scaler}, {scaler}), scale=(3., 3.), ratio=(2., 2.), p=1.),
|
75 |
+
kornia.augmentation.RandomHorizontalFlip(p=0.7),
|
76 |
+
kornia.augmentation.RandomGrayscale(p=0.5),
|
77 |
+
)
|
78 |
+
""",
|
79 |
+
]
|
80 |
+
|
81 |
+
selected_transform = st.selectbox(
|
82 |
+
"Pick an augmentation pipeline example:", predefined_transforms
|
83 |
+
)
|
84 |
+
|
85 |
+
st.write("Transform to apply:")
|
86 |
+
readonly = False
|
87 |
+
content = st_ace(
|
88 |
+
value=selected_transform,
|
89 |
+
height=150,
|
90 |
+
language="python",
|
91 |
+
keybinding="vscode",
|
92 |
+
show_gutter=True,
|
93 |
+
show_print_margin=True,
|
94 |
+
wrap=False,
|
95 |
+
auto_update=False,
|
96 |
+
readonly=readonly,
|
97 |
+
)
|
98 |
+
if content:
|
99 |
+
# st.write(content)
|
100 |
+
transform = set_transform(content)
|
101 |
+
|
102 |
+
# st.write(transform)
|
103 |
+
|
104 |
+
# with st.echo():
|
105 |
+
# transform = nn.Sequential(
|
106 |
+
# K.RandomAffine(360),
|
107 |
+
# K.ColorJitter(0.2, 0.3, 0.2, 0.3)
|
108 |
+
# )
|
109 |
+
|
110 |
+
process = st.button("Next Batch")
|
111 |
+
|
112 |
+
# Fake dataloader
|
113 |
+
image_batch = torch.stack(batch_size * [image])
|
114 |
+
|
115 |
+
|
116 |
+
image_batch.to(device)
|
117 |
+
transformeds = None
|
118 |
+
try:
|
119 |
+
transformeds = transform(image_batch)
|
120 |
+
except Exception as e:
|
121 |
+
st.write(f"There was an error: {e}")
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
cols = st.columns(4)
|
127 |
+
|
128 |
+
# st.image(F.to_pil_image(make_grid(transformeds)))
|
129 |
+
if transformeds is not None:
|
130 |
+
for i, x in enumerate(transformeds):
|
131 |
+
i = i % 4
|
132 |
+
cols[i].image(F.to_pil_image(x), use_column_width=True)
|
133 |
+
|
134 |
+
st.markdown(
|
135 |
+
"There are a lot more transformations available: [Documentation](https://kornia.readthedocs.io/en/latest/augmentation.module.html)"
|
136 |
+
)
|
137 |
+
st.markdown(
|
138 |
+
"Kornia can do a lot more than augmentations~ [Check it out](https://kornia.readthedocs.io/en/latest/introduction.html#highlighted-features)"
|
139 |
+
)
|
140 |
+
# if process:
|
141 |
+
# pass
|
142 |
+
|