Spaces: Runtime error
NimaBoscarino committed · Commit 4756ce1
1 Parent(s): 4b98f56
Refactor, polish (WIP)

Browse files
- app.py +16 -26
- inferences.py +107 -0
app.py
CHANGED
@@ -1,41 +1,27 @@
 import os
 import gradio as gr
 import googlemaps
-import
-from PIL import Image
+from skimage import io

 os.system('gdown https://drive.google.com/u/0/uc?id=18OCUIy7JQ2Ow_-cC5xn_hhDn-Bp45N1K')
 os.system('unzip release-github-v1.zip')
 os.system('mkdir config')
 os.system('mv model config')

+from inferences import ClimateGAN
+
 API_KEY = os.environ.get("API_KEY")
 gmaps = googlemaps.Client(key=API_KEY)
+model = ClimateGAN(model_path="config/model/masker")

 def predict(place):
-    # I don't think I need to do any error handling, I'll just let Gradio manage it. (Test it with a bad place though)
     geocode_result = gmaps.geocode(place)
     loc = geocode_result[0]['geometry']['location']
     static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=400x400&location={loc['lat']},{loc['lng']}&fov=80&heading=70&pitch=0&key={API_KEY}"

-
-
-
-    file = open("./inputs/loc_image.png", "wb")
-    file.write(img)
-    file.close()
-
-    # Next improvement is to refactor apply_events so that I can load the model in memory when the interface launces
-    # Would save ~10 seconds according to the logs.
-    # I would also like to not have to write the image to the filesystem, and instead just hold it in memory. As it is right now it would mess up with more than one person using it.
-    os.system('python apply_events.py -b 1 -i ./inputs -r config/model/masker --output_path ./outputs --overwrite')
-
-    # Also if I refactor then I wouldn't have to read from the filesystem...
-    # I would like to show all three images
-    os.system('ls -R ./outputs')
-
-    out_img = Image.open("./outputs/loc_image_flood_640.png")
-    return out_img # I actually want to return all 3 images.
+    img_np = io.imread(static_map_url)
+    flood, wildfire, smog = model.inference(img_np)
+    return img_np, flood, wildfire, smog


 gr.Interface(
@@ -43,10 +29,14 @@ gr.Interface(
     inputs=[
         gr.inputs.Textbox(label="Address or place name")
     ],
-    outputs="image",
+    outputs=["image", "image", "image", "image"],
     title="ClimateGAN",
-
-
-
-
+    description="Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.",
+    article="<p style='text-align: center'>This project is a clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>",
+    examples=[
+        "Kafka's Great Northern Way, Vancouver",
+        "Simon Fraser University",
+        "Duomo, Milano"
+    ],
+    css=".footer{display:none !important}",
 ).launch()
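For reference, the new predict() path replaces the old write-to-disk and apply_events.py subprocess flow with a single in-memory call: skimage's io.imread can fetch and decode the Street View response straight from the URL, and the four returned arrays map positionally onto the four "image" outputs. A minimal sketch of just that step, not part of the commit; the coordinates are a hypothetical example and a valid API_KEY is assumed to be set, as in app.py:

import os
from skimage import io

API_KEY = os.environ.get("API_KEY")  # assumed to be configured, as in app.py
# Hypothetical location (roughly downtown Vancouver); any lat/lng pair works.
url = (
    "https://maps.googleapis.com/maps/api/streetview"
    f"?size=400x400&location=49.2827,-123.1207&fov=80&heading=70&pitch=0&key={API_KEY}"
)
img_np = io.imread(url)  # HTTP fetch + decode in one step, no temp file on disk
print(img_np.shape)      # e.g. (400, 400, 3) uint8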
inferences.py
ADDED
@@ -0,0 +1,107 @@
+import torch
+from skimage.color import rgba2rgb
+from skimage.transform import resize
+import numpy as np
+
+from climategan.trainer import Trainer
+
+
+def uint8(array):
+    """
+    convert an array to np.uint8 (does not rescale or anything else than changing dtype)
+
+    Args:
+        array (np.array): array to modify
+
+    Returns:
+        np.array(np.uint8): converted array
+    """
+    return array.astype(np.uint8)
+
+def resize_and_crop(img, to=640):
+    """
+    Resizes an image so that it keeps the aspect ratio and the smallest dimension
+    is `to`, then crops this resized image in its center so that the output is `to x to`
+    without aspect ratio distortion
+
+    Args:
+        img (np.array): np.uint8 255 image
+
+    Returns:
+        np.array: [0, 1] np.float32 image
+    """
+    # resize keeping aspect ratio: smallest dim is 640
+    h, w = img.shape[:2]
+    if h < w:
+        size = (to, int(to * w / h))
+    else:
+        size = (int(to * h / w), to)
+
+    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
+    r_img = uint8(r_img)
+
+    # crop in the center
+    H, W = r_img.shape[:2]
+
+    top = (H - to) // 2
+    left = (W - to) // 2
+
+    rc_img = r_img[top : top + to, left : left + to, :]
+
+    return rc_img / 255.0
+
+def to_m1_p1(img):
+    """
+    rescales a [0, 1] image to [-1, +1]
+
+    Args:
+        img (np.array): float32 numpy array of an image in [0, 1]
+        i (int): Index of the image being rescaled
+
+    Raises:
+        ValueError: If the image is not in [0, 1]
+
+    Returns:
+        np.array(np.float32): array in [-1, +1]
+    """
+    if img.min() >= 0 and img.max() <= 1:
+        return (img.astype(np.float32) - 0.5) * 2
+    raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
+
+# No need to do any timing in this, since it's just for the HF Space
+class ClimateGAN():
+    def __init__(self, model_path) -> None:
+        torch.set_grad_enabled(False)
+        self.target_size = 640
+        self.trainer = Trainer.resume_from_path(
+            model_path,
+            setup=True,
+            inference=True,
+            new_exp=None,
+        )
+
+    # Does all three inferences at the moment.
+    def inference(self, orig_image):
+        image, new_size = self._preprocess_image(orig_image)
+
+        image = np.stack(image)
+        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
+        outputs = self.trainer.infer_all(
+            image,
+            numpy=True,
+            bin_value=0.5,
+        )
+
+        return outputs['flood'], outputs['wildfire'], outputs['smog']
+
+    def _preprocess_image(self, img):
+        # rgba to rgb
+        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
+
+        # to args.target_size
+        data = resize_and_crop(data, self.target_size)
+        new_size = (self.target_size, self.target_size)
+
+        # resize() produces [0, 1] images, rescale to [-1, 1]
+        data = to_m1_p1(data)
+        return data, new_size
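As a quick illustration of the preprocessing contract above (not part of the commit): for the square 400x400 tile app.py requests, the smaller side is scaled to 640, the center crop is a no-op, and the result is a float32 array in [-1, 1] ready for Trainer.infer_all. The sketch below mirrors resize_and_crop and to_m1_p1 as standalone code rather than importing inferences.py, so it runs with only numpy and scikit-image; the random tile is a stand-in for a real Street View image.

import numpy as np
from skimage.transform import resize

tile = np.random.randint(0, 256, (400, 400, 3), dtype=np.uint8)  # stand-in for a 400x400 Street View tile

to = 640
h, w = tile.shape[:2]
size = (to, int(to * w / h)) if h < w else (int(to * h / w), to)   # (640, 640) for a square input
r_img = resize(tile, size, preserve_range=True, anti_aliasing=True).astype(np.uint8)

top = (r_img.shape[0] - to) // 2   # 0 here: nothing to crop away vertically
left = (r_img.shape[1] - to) // 2  # 0 here as well
crop = r_img[top:top + to, left:left + to, :] / 255.0              # floats in [0, 1]

prepped = (crop.astype(np.float32) - 0.5) * 2                      # float32 in [-1, 1]
print(prepped.shape, float(prepped.min()), float(prepped.max()))   # (640, 640, 3), roughly -1.0 and 1.0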