hocherie committed on
Commit
df37182
1 Parent(s): bb018e6

Revert "added more examples"

Browse files

This reverts commit bb018e685e48b00bdb71b50254c58f294c1cea1a.

app.py DELETED
@@ -1,145 +0,0 @@
1
- import gradio as gr
2
- from matplotlib import pyplot as plt
3
- from mapper.utils.io import read_image
4
- from mapper.utils.exif import EXIF
5
- from mapper.utils.wrappers import Camera
6
- from mapper.data.image import rectify_image, pad_image, resize_image
7
- from mapper.utils.viz_2d import one_hot_argmax_to_rgb, plot_images
8
- from mapper.module import GenericModule
9
- from perspective2d import PerspectiveFields
10
- import torch
11
- import numpy as np
12
- from typing import Optional, Tuple
13
- from omegaconf import OmegaConf
14
-
15
# HTML snippet handed to gr.Interface as the demo description.
# NOTE(review): confirm the arXiv identifier in the "Paper" link is the intended one.
description = """
<h1 align="center">
<ins>MapItAnywhere (MIA) </ins>
<br>
Empowering Bird’s Eye View Mapping using Large-scale Public Data
<br>
<h3 align="center">
<a href="https://mapitanywhere.github.io" target="_blank">Project Page</a> |
<a href="https://arxiv.org/abs/2109.08203" target="_blank">Paper</a> |
<a href="https://github.com/MapItAnywhere/MapItAnywhere" target="_blank">Code</a>
</h3>
<p align="center">
Mapper generates birds-eye-view maps from in-the-wild monocular first-person view images. You can try our demo by uploading your images or using the examples provided. Tip: You can also try out images across the world using <a href="https://www.mapillary.com/app" target="_blank">Mapillary</a> &#128521;
</p>
"""

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Configuration loaded from "config.yaml" (path relative to the working directory).
cfg = OmegaConf.load("config.yaml")
34
-
35
class ImageCalibrator(PerspectiveFields):
    """PerspectiveFields model used to estimate roll/pitch and a pinhole camera."""

    def __init__(self, version: str = "Paramnet-360Cities-edina-centered"):
        """Initialise the parent model for ``version`` and put it in eval mode."""
        super().__init__(version)
        self.eval()

    def run(
        self,
        image_rgb: np.ndarray,
        focal_length: Optional[float] = None,
        exif: Optional[EXIF] = None,
    ) -> Tuple[Tuple[float, float], Camera]:
        """Estimate the roll/pitch pair and a SIMPLE_PINHOLE camera for an image.

        The focal length is chosen in this order of precedence:
        1. the ``focal_length`` argument, when given;
        2. the non-zero focal ratio from ``exif`` times the longer image side;
        3. the vertical field of view predicted by the network.

        Args:
            image_rgb: Image array whose first two dimensions are height and
                width (assumed channel-last RGB, as the name suggests).
            focal_length: Optional focal length that overrides every estimate.
            exif: Optional EXIF reader queried via ``extract_focal()``.

        Returns:
            ``((roll, pitch), camera)`` where roll and pitch are the scalar
            ``pred_roll`` / ``pred_pitch`` network outputs.
        """
        height, width = image_rgb.shape[:2]

        # Only consult EXIF when the caller did not pin the focal length.
        if focal_length is None and exif is not None:
            _unused, focal_ratio = exif.extract_focal()
            if focal_ratio != 0:
                focal_length = focal_ratio * max(height, width)

        # Reverse the last axis so the array is passed as `img_bgr`.
        prediction = self.inference(img_bgr=image_rgb[..., ::-1])
        roll = prediction["pred_roll"].item()
        pitch = prediction["pred_pitch"].item()

        if focal_length is None:
            # Last resort: derive the focal length from the predicted vertical
            # field of view, which is converted from degrees to radians here.
            vfov_deg = prediction["pred_vfov"].item()
            focal_length = height / 2 / np.tan(np.deg2rad(vfov_deg) / 2)

        # Params are [focal length, cx, cy], with the principal point at the
        # image centre plus a half-pixel offset.
        intrinsics = [focal_length, width / 2 + 0.5, height / 2 + 0.5]
        pinhole = Camera.from_dict(
            {
                "model": "SIMPLE_PINHOLE",
                "width": width,
                "height": height,
                "params": intrinsics,
            }
        )
        return (roll, pitch), pinhole
66
-
67
def preprocess_pipeline(image, roll_pitch, camera):
    """Build the batched input dictionary that is fed to the model.

    Args:
        image: NumPy image that is permuted from (H, W, C) to (C, H, W);
            values are divided by 255 (presumably 8-bit input — confirm).
        roll_pitch: Two-element sequence; its negated entries are passed to
            ``rectify_image``.
        camera: Camera object that is moved to ``device`` and updated by the
            resize and pad steps.

    Returns:
        Dict with keys ``"image"``, ``"valid"`` and ``"camera"``, each with a
        leading batch dimension of one and placed on ``device``.
    """
    side = 512  # size passed to both resize_image and pad_image

    pixels = (torch.from_numpy(image).float() / 255).permute(2, 0, 1).to(device)
    camera = camera.to(device)

    # Rectify with the negated roll and pitch (presumably to level the image).
    pixels, mask = rectify_image(
        pixels, camera.float(), -roll_pitch[0], -roll_pitch[1]
    )

    # Kept from the original pipeline. For a tuple this only rebinds the local
    # name to an empty tuple; a mutable list/array argument would be modified
    # in place. The value is not read again below.
    roll_pitch *= 0

    pixels, _, camera, mask = resize_image(
        image=pixels, size=side, camera=camera, fn=max, valid=mask
    )
    pixels, mask, camera = pad_image(pixels, side, camera, mask)

    batched_camera = torch.stack([camera])

    batch = {}
    batch["image"] = pixels.unsqueeze(0).to(device)
    batch["valid"] = mask.unsqueeze(0).to(device)
    batch["camera"] = batched_camera.float().to(device)
    return batch
95
-
96
-
97
# Calibration network shared by every request.
calibrator = ImageCalibrator().to(device)

# Restore the mapper from its checkpoint. `load_from_checkpoint` builds and
# returns a new module, so it is called on the class rather than on a
# throwaway `GenericModule(cfg)` instance: this avoids constructing the model
# twice, and recent Lightning releases reject the instance-call form.
# NOTE(review): assumes GenericModule is a LightningModule — confirm.
model = GenericModule.load_from_checkpoint(
    "trained_weights/mapper-excl-ood.ckpt", strict=False, cfg=cfg
)
model = model.to(device)
model = model.eval()
102
-
103
def run(input_img):
    """Run the full demo pipeline on an uploaded file and return a figure.

    Args:
        input_img: Uploaded file object; only its ``.name`` attribute (the
            path on disk) is used.

    Returns:
        The current matplotlib figure, containing the input image next to the
        colourised prediction.
    """
    image_path = input_img.name

    image = read_image(image_path)
    with open(image_path, "rb") as fid:
        # The EXIF reader takes a callable that returns the image's (H, W).
        exif = EXIF(fid, lambda: image.shape[:2])

    gravity, camera = calibrator.run(image, exif=exif)

    data = preprocess_pipeline(image, gravity, camera)

    # Pure inference: disable autograd so no graph or intermediate
    # activations are retained across requests.
    with torch.no_grad():
        res = model(data)

    prediction = res['output']
    # Colourise the prediction and convert it to an (H, W, C) integer array.
    # The literal 6 is presumably the number of classes — confirm.
    rgb_prediction = (
        one_hot_argmax_to_rgb(prediction, 6)
        .squeeze(0)
        .permute(1, 2, 0)
        .cpu()
        .long()
        .numpy()
    )
    # Drop the last entry along the final axis of the validity mask, then set
    # every pixel outside the mask to 255.
    valid = res['valid_bev'].squeeze(0)[..., :-1]
    rgb_prediction[~valid.cpu().numpy()] = 255

    # TODO: add legend here

    plot_images(
        [image, rgb_prediction],
        titles=["Input Image", "Prediction"],
        pad=2,
        adaptive=True,
    )

    return plt.gcf()
125
-
126
-
127
# Example inputs offered in the demo UI: one single-element row per image,
# matching the single file input. Every row needs its trailing comma;
# without it two adjacent list literals are parsed as a subscript
# (`["a"]["b"]`) and raise TypeError at import time.
examples = [
    ["examples/left_crossing.jpg"],
    ["examples/crossing.jpg"],
    ["examples/two_roads.jpg"],
    ["examples/night_road.jpg"],
    ["examples/night_crossing.jpg"],
]
134
-
135
# One file-upload input restricted to image files, and one plot output.
image_input = gr.File(file_types=["image"], label="Input Image")
prediction_plot = gr.Plot(label="Prediction", format="png")

# Wire `run` into a Gradio interface with the description and example rows.
demo = gr.Interface(
    fn=run,
    inputs=[image_input],
    outputs=[prediction_plot],
    description=description,
    examples=examples,
)

# Listen on all interfaces and do not create a public share link.
demo.launch(share=False, server_name="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/crossing.jpg DELETED

Git LFS Details

  • SHA256: 0bfaa2c0946d80b19513a90a253478efddb887ac49bcf48979c709af3caddc49
  • Pointer size: 130 Bytes
  • Size of remote file: 87.1 kB
examples/left_crossing.jpg DELETED

Git LFS Details

  • SHA256: 09a38947833bbec9487090f5046eb28c471b59bea2f5ed48bb571de397fbaa04
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
examples/night_crossing.jpg DELETED

Git LFS Details

  • SHA256: 08c2a7abb103a83a099599e8da4126a956784854412da92e81e3d791a55e034b
  • Pointer size: 130 Bytes
  • Size of remote file: 49.6 kB
examples/night_road.jpg DELETED

Git LFS Details

  • SHA256: d097121e96cc190076ce4cb1ad0f1318370d061733e0e26f5677a6df576d2af3
  • Pointer size: 130 Bytes
  • Size of remote file: 43 kB
examples/two_roads.jpg DELETED

Git LFS Details

  • SHA256: 72e0486eb388690c5dd91ab0a1ae89b683affcad7a6a4a0b595d611c03676260
  • Pointer size: 130 Bytes
  • Size of remote file: 88.5 kB