Spaces:
Running on CPU Upgrade

osv5m commited on
Commit
1aae222
ยท
verified ยท
1 Parent(s): eaa008b

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -291
app.py DELETED
@@ -1,291 +0,0 @@
1
- """Requires gradio==3.44.0"""
2
- import io
3
- import os
4
- import uuid
5
- import matplotlib
6
- import time
7
- import matplotlib.style as mplstyle
8
- mplstyle.use(['fast'])
9
- from os.path import join
10
- from PIL import Image
11
- import pandas as pd
12
- import reverse_geocoder as rg
13
- import cartopy.crs as ccrs
14
- import cartopy.feature as cfeature
15
- import matplotlib.pyplot as plt
16
- from math import radians, sin, cos, sqrt, asin, exp
17
- from collections import defaultdict
18
- import wandb
19
- import shutil
20
-
21
-
22
- IMAGE_FOLDER = './select'
23
- CSV_FILE = './select.csv'
24
- RESULTS_DIR = './results'
25
- RULES = """# Plonk ๐ŸŒ ๐ŸŒŽ ๐ŸŒ
26
- ## Total: 50 pictures
27
- ## Estimated time: 15min
28
- ### How it works:
29
- - Click on the map ๐Ÿ—บ๏ธ (left) to the location at which you think the image ๐Ÿ–ผ๏ธ (right) was captured!
30
- - Click next to move to the next image.
31
- โš ๏ธ Your selection is final!
32
- ### Click "start" to begin...
33
- """
34
-
35
- def haversine(lat1, lon1, lat2, lon2):
36
- if (lat1 is None) or (lon1 is None) or (lat2 is None) or (lon2 is None):
37
- return 0
38
- R = 6371 # radius of the earth in km
39
- dLat = radians(lat2 - lat1)
40
- dLon = radians(lon2 - lon1)
41
- a = (
42
- sin(dLat / 2.0) ** 2
43
- + cos(radians(lat1)) * cos(radians(lat2)) * sin(dLon / 2.0) ** 2
44
- )
45
- c = 2 * asin(sqrt(a))
46
- distance = R * c
47
- return distance
48
-
49
- def geoscore(d):
50
- return 5000 * exp(-d / 1492.7)
51
-
52
-
53
- class Engine(object):
54
- def __init__(self, image_folder, csv_file, cache_path):
55
- self.image_folder = image_folder
56
- self.load_images_and_coordinates(csv_file)
57
- self.cache_path = cache_path
58
-
59
- # Initialize the score and distance lists
60
- self.index = 0
61
- self.stats = defaultdict(list)
62
-
63
- # Create the figure and canvas only once
64
- self.fig = plt.Figure(figsize=(10, 6))
65
- self.ax = self.fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
66
- self.MIN_LON, self.MAX_LON, self.MIN_LAT, self.MAX_LAT = self.ax.get_extent()
67
-
68
- def load_images_and_coordinates(self, csv_file):
69
- # Load the CSV
70
- df = pd.read_csv(csv_file)
71
-
72
- # Get the image filenames and their coordinates
73
- self.images = df['id'].tolist()[:]
74
- self.coordinates = df[['longitude', 'latitude']].values.tolist()[:]
75
- self.admins = df[['city', 'sub-region', 'region', 'country']].values.tolist()[:]
76
- self.preds = df[['pred_longitude', 'pred_latitude']].values.tolist()[:]
77
-
78
- def isfinal(self):
79
- return self.index == len(self.images)-1
80
-
81
- def load_image(self):
82
- if self.index > len(self.images)-1:
83
- self.master.update_idletasks()
84
- self.finish()
85
-
86
- self.ax.clear()
87
- self.ax.set_global()
88
- self.ax.stock_img()
89
- self.ax.add_feature(cfeature.COASTLINE)
90
- self.ax.add_feature(cfeature.BORDERS, linestyle=':')
91
- self.fig.canvas.draw()
92
- pil = self.get_figure()
93
- self.set_clock()
94
- return pil, os.path.join(self.image_folder, f"{self.images[self.index]}.jpg"), '### ' + str(self.index + 1) + '/' + str(len(self.images))
95
-
96
- def get_figure(self):
97
- img_buf = io.BytesIO()
98
- self.fig.savefig(img_buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300)
99
- pil = Image.open(img_buf)
100
- self.width, self.height = pil.size
101
- return pil
102
-
103
- def normalize_pixels(self, click_lon, click_lat):
104
- return self.MIN_LON + click_lon * (self.MAX_LON-self.MIN_LON) / self.width, self.MIN_LAT + (self.height - click_lat+1) * (self.MAX_LAT-self.MIN_LAT) / self.height
105
-
106
- def set_clock(self):
107
- self.time = time.time()
108
-
109
- def get_clock(self):
110
- return time.time() - self.time
111
-
112
- def click(self, click_lon, click_lat):
113
- time_elapsed = self.get_clock()
114
- self.stats['times'].append(time_elapsed)
115
-
116
- # convert click_lon, click_lat to lat, lon (given that you have the borders of the image)
117
- # click_lon and click_lat is in pixels
118
- # lon and lat is in degrees
119
- click_lon, click_lat = self.normalize_pixels(click_lon, click_lat)
120
- self.stats['clicked_locations'].append((click_lat, click_lon))
121
- true_lon, true_lat = self.coordinates[self.index]
122
- pred_lon, pred_lat = self.preds[self.index]
123
-
124
- self.ax.plot(pred_lon, pred_lat, 'gv', transform=ccrs.Geodetic())
125
- self.ax.plot([true_lon, pred_lon], [true_lat, pred_lat], color='green', linewidth=1, transform=ccrs.Geodetic())
126
- self.ax.plot(click_lon, click_lat, 'bo', transform=ccrs.Geodetic())
127
- self.ax.plot([true_lon, click_lon], [true_lat, click_lat], color='blue', linewidth=1, transform=ccrs.Geodetic())
128
- self.ax.plot(true_lon, true_lat, 'rx', transform=ccrs.Geodetic())
129
-
130
- distance = haversine(true_lat, true_lon, click_lat, click_lon)
131
- score = geoscore(distance)
132
- self.stats['scores'].append(score)
133
- self.stats['distances'].append(distance)
134
-
135
- average_text = self.update_average_display()
136
- result_text = (f"### GeoScore: {score:.0f}, distance: {distance:.0f} km\n ")
137
-
138
- self.cache(self.index+1, score, distance, (click_lat, click_lon), time_elapsed)
139
- return self.get_figure(), result_text + average_text
140
-
141
- def next_image(self):
142
- # Go to the next image
143
- self.index += 1
144
- return self.load_image()
145
-
146
- def update_average_display(self):
147
- # Calculate the average values
148
- avg_score = sum(self.stats['scores']) / len(self.stats['scores']) if self.stats['scores'] else 0
149
- avg_distance = sum(self.stats['distances']) / len(self.stats['distances']) if self.stats['distances'] else 0
150
-
151
- # Update the text box
152
- return f"### Average GeoScore: {avg_score:.0f}, Average distance: {avg_distance:.0f} km"
153
-
154
- def finish(self):
155
- clicks = rg.search(self.stats['clicked_locations'])
156
- clicked_admins = [[click['name'], click['admin2'], click['admin1'], click['cc']] for click in clicks]
157
-
158
- correct = [0,0,0,0]
159
- valid = [0,0,0,0]
160
-
161
- for clicked_admin, true_admin in zip(clicked_admins, self.admins):
162
- for i in range(4):
163
- if true_admin[i]!= 'nan':
164
- valid[i] += 1
165
- if true_admin[i] == clicked_admin[i]:
166
- correct[i] += 1
167
-
168
- avg_city_accuracy = correct[0] / valid[0]
169
- avg_area_accuracy = correct[1] / valid[1]
170
- avg_region_accuracy = correct[2] / valid[2]
171
- avg_country_accuracy = correct[3] / valid[3]
172
-
173
- avg_score = sum(self.stats['scores']) / len(self.stats['scores']) if self.stats['scores'] else 0
174
- avg_distance = sum(self.stats['distances']) / len(self.stats['distances']) if self.stats['distances'] else 0
175
-
176
- final_results = (
177
- f"Average GeoScore: {avg_score:.0f} \n" +
178
- f"Average distance: {avg_distance:.0f} km \n" +
179
- f"Country Acc: {100*avg_country_accuracy:.1f} \n" +
180
- f"Region Acc: {100*avg_region_accuracy:.1f} \n" +
181
- f"Area Acc: {100*avg_area_accuracy:.1f} \n" +
182
- f"City Acc: {100*avg_city_accuracy:.1f}"
183
- )
184
-
185
- self.cache_final(final_results)
186
-
187
- # Update the text box
188
- return f"# Your stats ๐ŸŒ\n" + final_results + f" \n# Thanks for playing โค๏ธ"
189
-
190
- # Function to save the game state
191
- def cache(self, index, score, distance, location, time_elapsed):
192
- if not os.path.exists(self.cache_path):
193
- os.makedirs(self.cache_path)
194
-
195
- with open(join(self.cache_path, str(index).zfill(2) + '.txt'), 'w') as f:
196
- print(f"{score}, {distance}, {location[0]}, {location[1]}, {time_elapsed}", file=f)
197
-
198
- # Function to save the game state
199
- def cache_final(self, final_results):
200
- times = ', '.join(map(str, self.stats['times']))
201
- fname = join(self.cache_path, 'full.txt')
202
- with open(fname, 'w') as f:
203
- print(f"{final_results}" + '\n Times: ' + times, file=f)
204
-
205
- zip_ = self.cache_path.rstrip('/') + '.zip'
206
- archived = shutil.make_archive(self.cache_path.rstrip('/'), 'zip', self.cache_path)
207
- try:
208
- wandb.init(project="plonk")
209
- artifact = wandb.Artifact('results', type='results')
210
- artifact.add_file(zip_)
211
- wandb.log_artifact(artifact)
212
- wandb.finish()
213
- except Exception:
214
- print("Failed to log results to wandb")
215
- pass
216
-
217
- if os.path.isfile(zip_):
218
- os.remove(zip_)
219
-
220
-
221
- if __name__ == "__main__":
222
- # login with the key from secret
223
- wandb.login()
224
- csv_str = os.environ['csv']
225
- with open(CSV_FILE, 'w') as f:
226
- f.write(csv_str)
227
-
228
- import gradio as gr
229
- def click(state, evt: gr.SelectData):
230
- if state['clicked']:
231
- return gr.update(), gr.update()
232
- x, y = evt.index
233
- state['clicked'] = True
234
- image, text = state['engine'].click(x, y)
235
- return gr.update(value=image), gr.update(value=text)
236
-
237
- def next_(state):
238
- if state['clicked']:
239
- if state['engine'].isfinal():
240
- text = state['engine'].finish()
241
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=text), gr.update(visible=False)
242
- else:
243
- fig, image, text = state['engine'].next_image()
244
- state['clicked'] = False
245
- return gr.update(value=fig), gr.update(value=image), gr.update(value=text), gr.update(), gr.update()
246
- else:
247
- return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
248
-
249
- def start(state):
250
- # create a unique random temporary name under CACHE_DIR
251
- # generate random hex and make sure it doesn't exist under CACHE_DIR
252
- while True:
253
- path = str(uuid.uuid4().hex)
254
- name = os.path.join(RESULTS_DIR, path)
255
- if not os.path.exists(name):
256
- break
257
-
258
- state['engine'] = Engine(IMAGE_FOLDER, CSV_FILE, name)
259
- state['clicked'] = False
260
- fig, image, text = state['engine'].load_image()
261
-
262
- return (
263
- gr.update(value=fig, visible=True),
264
- gr.update(value=image, visible=True),
265
- gr.update(value=text, visible=True),
266
- gr.update(visible=True),
267
- gr.update(visible=True),
268
- gr.update(visible=False),
269
- gr.update(visible=False),
270
- gr.update(visible=False),
271
- gr.update(visible=False),
272
- )
273
-
274
- with gr.Blocks() as demo:
275
- state = gr.State({})
276
- rules = gr.Markdown(RULES, visible=True)
277
-
278
- start_button = gr.Button("Start", visible=True)
279
- with gr.Row():
280
- map_ = gr.Image(label='Map', visible=False)
281
- image_ = gr.Image(label='Image', visible=False)
282
- with gr.Row():
283
- text = gr.Markdown("", visible=False)
284
- text_count = gr.Markdown("", visible=False)
285
-
286
- next_button = gr.Button("Next", visible=False)
287
- start_button.click(start, inputs=[state], outputs=[map_, image_, text_count, text, next_button, rules, state, start_button])
288
- map_.select(click, inputs=[state], outputs=[map_, text])
289
- next_button.click(next_, inputs=[state], outputs=[map_, image_, text_count, text, next_button])
290
-
291
- demo.launch()