hocherie committed
Commit bb018e6
1 Parent(s): e04391f

added more examples

app.py ADDED
@@ -0,0 +1,145 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+ from matplotlib import pyplot as plt
+ from omegaconf import OmegaConf
+ from typing import Optional, Tuple
+
+ from mapper.data.image import pad_image, rectify_image, resize_image
+ from mapper.module import GenericModule
+ from mapper.utils.exif import EXIF
+ from mapper.utils.io import read_image
+ from mapper.utils.viz_2d import one_hot_argmax_to_rgb, plot_images
+ from mapper.utils.wrappers import Camera
+ from perspective2d import PerspectiveFields
+
+ description = """
16
+ <h1 align="center">
17
+ <ins>MapItAnywhere (MIA) </ins>
18
+ <br>
19
+ Empowering Bird’s Eye View Mapping using Large-scale Public Data
20
+ <br>
21
+ <h3 align="center">
22
+ <a href="https://mapitanywhere.github.io" target="_blank">Project Page</a> |
23
+ <a href="https://arxiv.org/abs/2109.08203" target="_blank">Paper</a> |
24
+ <a href="https://github.com/MapItAnywhere/MapItAnywhere" target="_blank">Code</a>
25
+ </h3>
26
+ <p align="center">
27
+ Mapper generates birds-eye-view maps from in-the-wild monocular first-person view images. You can try our demo by uploading your images or using the examples provided. Tip: You can also try out images across the world using <a href="https://www.mapillary.com/app" target="_blank">Mapillary</a> &#128521;
28
+ </p>
29
+ """
30
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ cfg = OmegaConf.load("config.yaml")
+
+ class ImageCalibrator(PerspectiveFields):
+     """Estimates gravity (roll/pitch) and camera intrinsics from a single image."""
+
+     def __init__(self, version: str = "Paramnet-360Cities-edina-centered"):
+         super().__init__(version)
+         self.eval()
+
+     def run(
+         self,
+         image_rgb: np.ndarray,
+         focal_length: Optional[float] = None,
+         exif: Optional[EXIF] = None,
+     ) -> Tuple[Tuple[float, float], Camera]:
+         h, w, *_ = image_rgb.shape
+         # Prefer a focal length recovered from the EXIF metadata, if present.
+         if focal_length is None and exif is not None:
+             _, focal_ratio = exif.extract_focal()
+             if focal_ratio != 0:
+                 focal_length = focal_ratio * max(h, w)
+
+         # PerspectiveFields expects BGR input.
+         calib = self.inference(img_bgr=image_rgb[..., ::-1])
+         roll_pitch = (calib["pred_roll"].item(), calib["pred_pitch"].item())
+         # Otherwise, fall back to the predicted vertical field of view.
+         if focal_length is None:
+             vfov = calib["pred_vfov"].item()
+             focal_length = h / 2 / np.tan(np.deg2rad(vfov) / 2)
+
+         camera = Camera.from_dict(
+             {
+                 "model": "SIMPLE_PINHOLE",
+                 "width": w,
+                 "height": h,
+                 "params": [focal_length, w / 2 + 0.5, h / 2 + 0.5],
+             }
+         )
+         return roll_pitch, camera
+
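+ # Worked example of the vertical-FoV fallback above (illustrative numbers,
+ # not taken from the paper): for an image of height h = 1024 px and a
+ # predicted vfov of 60 degrees, f = 1024 / 2 / tan(30 deg) ≈ 886.8 px.
+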
+ def preprocess_pipeline(image, roll_pitch, camera):
+     image = torch.from_numpy(image).float() / 255
+     image = image.permute(2, 0, 1).to(device)
+     camera = camera.to(device)
+
+     # Rectify the image so that its optical axis is gravity-aligned.
+     image, valid = rectify_image(image, camera.float(), -roll_pitch[0], -roll_pitch[1])
+     # After rectification the residual roll/pitch is zero.
+     roll_pitch = (0.0, 0.0)
+
+     image, _, camera, valid = resize_image(
+         image=image,
+         size=512,
+         camera=camera,
+         fn=max,
+         valid=valid,
+     )
+     image, valid, camera = pad_image(
+         image, 512, camera, valid
+     )
+
+     # Add a batch dimension.
+     camera = torch.stack([camera])
+     return {
+         "image": image.unsqueeze(0).to(device),
+         "valid": valid.unsqueeze(0).to(device),
+         "camera": camera.float().to(device),
+     }
+
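+ # Shape sketch (assuming the resize/pad above yields a 512x512 crop):
+ # data["image"] is (1, 3, 512, 512), data["valid"] is (1, 512, 512), and
+ # data["camera"] holds batched pinhole intrinsics with batch size 1.
+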
+
+ calibrator = ImageCalibrator().to(device)
+ # Load the trained Mapper checkpoint (PyTorch Lightning).
+ model = GenericModule(cfg)
+ model = model.load_from_checkpoint("trained_weights/mapper-excl-ood.ckpt", strict=False, cfg=cfg)
+ model = model.to(device)
+ model = model.eval()
+
+ def run(input_img):
+     image_path = input_img.name
+
+     image = read_image(image_path)
+     with open(image_path, "rb") as fid:
+         exif = EXIF(fid, lambda: image.shape[:2])
+
+     # Estimate gravity (roll/pitch) and intrinsics, then preprocess and run Mapper.
+     gravity, camera = calibrator.run(image, exif=exif)
+     data = preprocess_pipeline(image, gravity, camera)
+     res = model(data)
+
+     # Convert the one-hot semantic prediction to an RGB map and white out
+     # BEV cells outside the valid field of view.
+     prediction = res["output"]
+     rgb_prediction = one_hot_argmax_to_rgb(prediction, 6).squeeze(0).permute(1, 2, 0).cpu().long().numpy()
+     valid = res["valid_bev"].squeeze(0)[..., :-1]
+     rgb_prediction[~valid.cpu().numpy()] = 255
+
+     # TODO: add legend here
+     plot_images([image, rgb_prediction], titles=["Input Image", "Prediction"], pad=2, adaptive=True)
+     return plt.gcf()
+
+
+ examples = [
+     ["examples/left_crossing.jpg"],
+     ["examples/crossing.jpg"],
+     ["examples/two_roads.jpg"],
+     ["examples/night_road.jpg"],
+     ["examples/night_crossing.jpg"],
+ ]
+
+ demo = gr.Interface(
+     fn=run,
+     inputs=[
+         gr.File(file_types=["image"], label="Input Image"),
+     ],
+     outputs=[
+         gr.Plot(label="Prediction", format="png"),
+     ],
+     description=description,
+     examples=examples,
+ )
+ demo.launch(share=False, server_name="0.0.0.0")
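
With the app running, the same endpoint can also be queried programmatically. A minimal sketch using gradio_client (assuming the server is reachable on Gradio's default port 7860; /predict is Gradio's default endpoint name for a single-function Interface):

    from gradio_client import Client, handle_file

    client = Client("http://127.0.0.1:7860")
    # The Interface has one File input: pass a local image path.
    # The returned value references the rendered prediction plot.
    result = client.predict(handle_file("examples/crossing.jpg"), api_name="/predict")
    print(result)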
examples/crossing.jpg ADDED

Git LFS Details

  • SHA256: 0bfaa2c0946d80b19513a90a253478efddb887ac49bcf48979c709af3caddc49
  • Pointer size: 130 Bytes
  • Size of remote file: 87.1 kB
examples/left_crossing.jpg ADDED

Git LFS Details

  • SHA256: 09a38947833bbec9487090f5046eb28c471b59bea2f5ed48bb571de397fbaa04
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
examples/night_crossing.jpg ADDED

Git LFS Details

  • SHA256: 08c2a7abb103a83a099599e8da4126a956784854412da92e81e3d791a55e034b
  • Pointer size: 130 Bytes
  • Size of remote file: 49.6 kB
examples/night_road.jpg ADDED

Git LFS Details

  • SHA256: d097121e96cc190076ce4cb1ad0f1318370d061733e0e26f5677a6df576d2af3
  • Pointer size: 130 Bytes
  • Size of remote file: 43 kB
examples/two_roads.jpg ADDED

Git LFS Details

  • SHA256: 72e0486eb388690c5dd91ab0a1ae89b683affcad7a6a4a0b595d611c03676260
  • Pointer size: 130 Bytes
  • Size of remote file: 88.5 kB