Nguyen Thai Thao Uyen commited on
Commit
586d4f8
·
1 Parent(s): 1617319
Files changed (2) hide show
  1. app.py +59 -202
  2. run.py +37 -0
app.py CHANGED
@@ -1,202 +1,59 @@
1
- import ipyleaflet as L
2
- from faicons import icon_svg
3
- from geopy.distance import geodesic, great_circle
4
- from shared import BASEMAPS, CITIES
5
- from shiny import reactive
6
- from shiny.express import input, render, ui
7
- from shinywidgets import render_widget
8
-
9
- city_names = sorted(list(CITIES.keys()))
10
-
11
- ui.page_opts(title="Location Distance Calculator", fillable=True)
12
- {"class": "bslib-page-dashboard"}
13
-
14
- with ui.sidebar():
15
- ui.input_selectize("loc1", "Location 1", choices=city_names, selected="New York")
16
- ui.input_selectize("loc2", "Location 2", choices=city_names, selected="London")
17
- ui.input_selectize(
18
- "basemap",
19
- "Choose a basemap",
20
- choices=list(BASEMAPS.keys()),
21
- selected="WorldImagery",
22
- )
23
- ui.input_dark_mode(mode="dark")
24
-
25
- with ui.layout_column_wrap(fill=False):
26
- with ui.value_box(showcase=icon_svg("globe"), theme="gradient-blue-indigo"):
27
- "Great Circle Distance"
28
-
29
- @render.text
30
- def great_circle_dist():
31
- circle = great_circle(loc1xy(), loc2xy())
32
- return f"{circle.kilometers.__round__(1)} km"
33
-
34
- with ui.value_box(showcase=icon_svg("ruler"), theme="gradient-blue-indigo"):
35
- "Geodisic Distance"
36
-
37
- @render.text
38
- def geo_dist():
39
- dist = geodesic(loc1xy(), loc2xy())
40
- return f"{dist.kilometers.__round__(1)} km"
41
-
42
- with ui.value_box(showcase=icon_svg("mountain"), theme="gradient-blue-indigo"):
43
- "Altitude Difference"
44
-
45
- @render.text
46
- def altitude():
47
- try:
48
- return f'{loc1()["altitude"] - loc2()["altitude"]} m'
49
- except TypeError:
50
- return "N/A (altitude lookup failed)"
51
-
52
-
53
- with ui.card():
54
- ui.card_header("Map (drag the markers to change locations)")
55
-
56
- @render_widget
57
- def map():
58
- return L.Map(zoom=4, center=(0, 0))
59
-
60
-
61
- # Reactive values to store location information
62
- loc1 = reactive.value()
63
- loc2 = reactive.value()
64
-
65
-
66
- # Update the reactive values when the selectize inputs change
67
- @reactive.effect
68
- def _():
69
- loc1.set(CITIES.get(input.loc1(), loc_str_to_coords(input.loc1())))
70
- loc2.set(CITIES.get(input.loc2(), loc_str_to_coords(input.loc2())))
71
-
72
-
73
- # When a marker is moved, the input value gets updated to "lat, lon",
74
- # so we decode that into a dict (and also look up the altitude)
75
- def loc_str_to_coords(x: str) -> dict:
76
- latlon = x.split(", ")
77
- if len(latlon) != 2:
78
- return {}
79
-
80
- lat = float(latlon[0])
81
- lon = float(latlon[1])
82
-
83
- try:
84
- import requests
85
-
86
- query = f"https://api.open-elevation.com/api/v1/lookup?locations={lat},{lon}"
87
- r = requests.get(query).json()
88
- altitude = r["results"][0]["elevation"]
89
- except Exception:
90
- altitude = None
91
-
92
- return {"latitude": lat, "longitude": lon, "altitude": altitude}
93
-
94
-
95
- # Convenient way to get the lat/lons as a tuple
96
- @reactive.calc
97
- def loc1xy():
98
- return loc1()["latitude"], loc1()["longitude"]
99
-
100
-
101
- @reactive.calc
102
- def loc2xy():
103
- return loc2()["latitude"], loc2()["longitude"]
104
-
105
-
106
- # Add marker for first location
107
- @reactive.effect
108
- def _():
109
- update_marker(map.widget, loc1xy(), on_move1, "loc1")
110
-
111
-
112
- # Add marker for second location
113
- @reactive.effect
114
- def _():
115
- update_marker(map.widget, loc2xy(), on_move2, "loc2")
116
-
117
-
118
- # Add line and fit bounds when either marker is moved
119
- @reactive.effect
120
- def _():
121
- update_line(map.widget, loc1xy(), loc2xy())
122
-
123
-
124
- # If new bounds fall outside of the current view, fit the bounds
125
- @reactive.effect
126
- def _():
127
- l1 = loc1xy()
128
- l2 = loc2xy()
129
-
130
- lat_rng = [min(l1[0], l2[0]), max(l1[0], l2[0])]
131
- lon_rng = [min(l1[1], l2[1]), max(l1[1], l2[1])]
132
- new_bounds = [
133
- [lat_rng[0], lon_rng[0]],
134
- [lat_rng[1], lon_rng[1]],
135
- ]
136
-
137
- b = map.widget.bounds
138
- if len(b) == 0:
139
- map.widget.fit_bounds(new_bounds)
140
- elif (
141
- lat_rng[0] < b[0][0]
142
- or lat_rng[1] > b[1][0]
143
- or lon_rng[0] < b[0][1]
144
- or lon_rng[1] > b[1][1]
145
- ):
146
- map.widget.fit_bounds(new_bounds)
147
-
148
-
149
- # Update the basemap
150
- @reactive.effect
151
- def _():
152
- update_basemap(map.widget, input.basemap())
153
-
154
-
155
- # ---------------------------------------------------------------
156
- # Helper functions
157
- # ---------------------------------------------------------------
158
-
159
-
160
- def update_marker(map: L.Map, loc: tuple, on_move: object, name: str):
161
- remove_layer(map, name)
162
- m = L.Marker(location=loc, draggable=True, name=name)
163
- m.on_move(on_move)
164
- map.add_layer(m)
165
-
166
-
167
- def update_line(map: L.Map, loc1: tuple, loc2: tuple):
168
- remove_layer(map, "line")
169
- map.add_layer(
170
- L.Polyline(locations=[loc1, loc2], color="blue", weight=2, name="line")
171
- )
172
-
173
-
174
- def update_basemap(map: L.Map, basemap: str):
175
- for layer in map.layers:
176
- if isinstance(layer, L.TileLayer):
177
- map.remove_layer(layer)
178
- map.add_layer(L.basemap_to_tiles(BASEMAPS[input.basemap()]))
179
-
180
-
181
- def remove_layer(map: L.Map, name: str):
182
- for layer in map.layers:
183
- if layer.name == name:
184
- map.remove_layer(layer)
185
-
186
-
187
- def on_move1(**kwargs):
188
- return on_move("loc1", **kwargs)
189
-
190
-
191
- def on_move2(**kwargs):
192
- return on_move("loc2", **kwargs)
193
-
194
-
195
- # When the markers are moved, update the selectize inputs to include the new
196
- # location (which results in the locations() reactive value getting updated,
197
- # which invalidates any downstream reactivity that depends on it)
198
- def on_move(id, **kwargs):
199
- loc = kwargs["location"]
200
- loc_str = f"{loc[0]}, {loc[1]}"
201
- choices = city_names + [loc_str]
202
- ui.update_selectize(id, selected=loc_str, choices=choices)
 
1
+ from pathlib import Path
2
+ from typing import List, Dict, Tuple
3
+ import matplotlib.colors as mpl_colors
4
+ import pandas as pd
5
+ import seaborn as sns
6
+ import shinyswatch
7
+ import run
8
+
9
+ from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
10
+ from transformers import SamModel, SamConfig, SamProcessor
11
+ import torch
12
+
13
+ sns.set_theme()
14
+
15
+ www_dir = Path(__file__).parent.resolve() / "www"
16
+
17
+ app_ui = ui.page_fillable(
18
+ shinyswatch.theme.minty(),
19
+ ui.layout_sidebar(
20
+ ui.sidebar(
21
+ ui.input_file("image_input", "Upload image: ", multiple=True),
22
+ ),
23
+ ui.output_image("image"),
24
+ ui.output_image("image_output"),
25
+ ui.output_image("single_patch_prediction"),
26
+ ui.output_image("single_patch_prob")
27
+ ),
28
+ )
29
+
30
+
31
+ def server(input: Inputs, output: Outputs, session: Session):
32
+ @output
33
+ @render.image
34
+ def image():
35
+ here = Path(__file__).parent
36
+ if input.image_input():
37
+ src = input.image_input()[0]['datapath']
38
+ img = {"src": src, "width": "500px"}
39
+ return img
40
+ return None
41
+
42
+ @output
43
+ @render.image
44
+ def image_output():
45
+ here = Path(__file__).parent
46
+ if input.image_input():
47
+ src = input.image_input()[0]['datapath']
48
+ img = {"src": src, "width": "500px"}
49
+ x = run.pred(src)
50
+ print(x)
51
+ return img
52
+ return None
53
+
54
+
55
+ app = App(
56
+ app_ui,
57
+ server,
58
+ static_assets=str(www_dir),
59
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
run.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import SamModel, SamConfig, SamProcessor
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import app
6
+ import os
7
+
8
+ def pred(src):
9
+ # os.environ['HUGGINGFACE_HUB_HOME'] = './.cache'
10
+ # Load the model configuration
11
+ cache_dir = "/code/cache"
12
+ model_config = SamConfig.from_pretrained("facebook/sam-vit-base",
13
+ cache_dir=cache_dir)
14
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base",
15
+ cache_dir=cache_dir)
16
+
17
+ # Create an instance of the model architecture with the loaded configuration
18
+ model = SamModel(config=model_config)
19
+ #Update the model by loading the weights from saved file.
20
+ model.load_state_dict(torch.load("sam_model.pth",
21
+ map_location=torch.device('cpu')))
22
+
23
+ new_image = np.array(Image.open(src))
24
+ inputs = processor(new_image, return_tensors="pt")
25
+ inputs = {k: v.to(device) for k, v in inputs.items()}
26
+ x = 1
27
+ # model.eval()
28
+ # # forward pass
29
+ # with torch.no_grad():
30
+ # outputs = model(**inputs, multimask_output=False)
31
+
32
+ # # apply sigmoid
33
+ # single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
34
+ # # convert soft mask to hard mask
35
+ # single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
36
+ # single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
37
+ return x