Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Delete app.py
Browse files
app.py
DELETED
@@ -1,291 +0,0 @@
|
|
1 |
-
"""Requires gradio==3.44.0"""
|
2 |
-
import io
|
3 |
-
import os
|
4 |
-
import uuid
|
5 |
-
import matplotlib
|
6 |
-
import time
|
7 |
-
import matplotlib.style as mplstyle
|
8 |
-
mplstyle.use(['fast'])
|
9 |
-
from os.path import join
|
10 |
-
from PIL import Image
|
11 |
-
import pandas as pd
|
12 |
-
import reverse_geocoder as rg
|
13 |
-
import cartopy.crs as ccrs
|
14 |
-
import cartopy.feature as cfeature
|
15 |
-
import matplotlib.pyplot as plt
|
16 |
-
from math import radians, sin, cos, sqrt, asin, exp
|
17 |
-
from collections import defaultdict
|
18 |
-
import wandb
|
19 |
-
import shutil
|
20 |
-
|
21 |
-
|
22 |
-
IMAGE_FOLDER = './select'
|
23 |
-
CSV_FILE = './select.csv'
|
24 |
-
RESULTS_DIR = './results'
|
25 |
-
RULES = """# Plonk ๐ ๐ ๐
|
26 |
-
## Total: 50 pictures
|
27 |
-
## Estimated time: 15min
|
28 |
-
### How it works:
|
29 |
-
- Click on the map ๐บ๏ธ (left) to the location at which you think the image ๐ผ๏ธ (right) was captured!
|
30 |
-
- Click next to move to the next image.
|
31 |
-
โ ๏ธ Your selection is final!
|
32 |
-
### Click "start" to begin...
|
33 |
-
"""
|
34 |
-
|
35 |
-
def haversine(lat1, lon1, lat2, lon2):
|
36 |
-
if (lat1 is None) or (lon1 is None) or (lat2 is None) or (lon2 is None):
|
37 |
-
return 0
|
38 |
-
R = 6371 # radius of the earth in km
|
39 |
-
dLat = radians(lat2 - lat1)
|
40 |
-
dLon = radians(lon2 - lon1)
|
41 |
-
a = (
|
42 |
-
sin(dLat / 2.0) ** 2
|
43 |
-
+ cos(radians(lat1)) * cos(radians(lat2)) * sin(dLon / 2.0) ** 2
|
44 |
-
)
|
45 |
-
c = 2 * asin(sqrt(a))
|
46 |
-
distance = R * c
|
47 |
-
return distance
|
48 |
-
|
49 |
-
def geoscore(d):
|
50 |
-
return 5000 * exp(-d / 1492.7)
|
51 |
-
|
52 |
-
|
53 |
-
class Engine(object):
|
54 |
-
def __init__(self, image_folder, csv_file, cache_path):
|
55 |
-
self.image_folder = image_folder
|
56 |
-
self.load_images_and_coordinates(csv_file)
|
57 |
-
self.cache_path = cache_path
|
58 |
-
|
59 |
-
# Initialize the score and distance lists
|
60 |
-
self.index = 0
|
61 |
-
self.stats = defaultdict(list)
|
62 |
-
|
63 |
-
# Create the figure and canvas only once
|
64 |
-
self.fig = plt.Figure(figsize=(10, 6))
|
65 |
-
self.ax = self.fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
|
66 |
-
self.MIN_LON, self.MAX_LON, self.MIN_LAT, self.MAX_LAT = self.ax.get_extent()
|
67 |
-
|
68 |
-
def load_images_and_coordinates(self, csv_file):
|
69 |
-
# Load the CSV
|
70 |
-
df = pd.read_csv(csv_file)
|
71 |
-
|
72 |
-
# Get the image filenames and their coordinates
|
73 |
-
self.images = df['id'].tolist()[:]
|
74 |
-
self.coordinates = df[['longitude', 'latitude']].values.tolist()[:]
|
75 |
-
self.admins = df[['city', 'sub-region', 'region', 'country']].values.tolist()[:]
|
76 |
-
self.preds = df[['pred_longitude', 'pred_latitude']].values.tolist()[:]
|
77 |
-
|
78 |
-
def isfinal(self):
|
79 |
-
return self.index == len(self.images)-1
|
80 |
-
|
81 |
-
def load_image(self):
|
82 |
-
if self.index > len(self.images)-1:
|
83 |
-
self.master.update_idletasks()
|
84 |
-
self.finish()
|
85 |
-
|
86 |
-
self.ax.clear()
|
87 |
-
self.ax.set_global()
|
88 |
-
self.ax.stock_img()
|
89 |
-
self.ax.add_feature(cfeature.COASTLINE)
|
90 |
-
self.ax.add_feature(cfeature.BORDERS, linestyle=':')
|
91 |
-
self.fig.canvas.draw()
|
92 |
-
pil = self.get_figure()
|
93 |
-
self.set_clock()
|
94 |
-
return pil, os.path.join(self.image_folder, f"{self.images[self.index]}.jpg"), '### ' + str(self.index + 1) + '/' + str(len(self.images))
|
95 |
-
|
96 |
-
def get_figure(self):
|
97 |
-
img_buf = io.BytesIO()
|
98 |
-
self.fig.savefig(img_buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300)
|
99 |
-
pil = Image.open(img_buf)
|
100 |
-
self.width, self.height = pil.size
|
101 |
-
return pil
|
102 |
-
|
103 |
-
def normalize_pixels(self, click_lon, click_lat):
|
104 |
-
return self.MIN_LON + click_lon * (self.MAX_LON-self.MIN_LON) / self.width, self.MIN_LAT + (self.height - click_lat+1) * (self.MAX_LAT-self.MIN_LAT) / self.height
|
105 |
-
|
106 |
-
def set_clock(self):
|
107 |
-
self.time = time.time()
|
108 |
-
|
109 |
-
def get_clock(self):
|
110 |
-
return time.time() - self.time
|
111 |
-
|
112 |
-
def click(self, click_lon, click_lat):
|
113 |
-
time_elapsed = self.get_clock()
|
114 |
-
self.stats['times'].append(time_elapsed)
|
115 |
-
|
116 |
-
# convert click_lon, click_lat to lat, lon (given that you have the borders of the image)
|
117 |
-
# click_lon and click_lat is in pixels
|
118 |
-
# lon and lat is in degrees
|
119 |
-
click_lon, click_lat = self.normalize_pixels(click_lon, click_lat)
|
120 |
-
self.stats['clicked_locations'].append((click_lat, click_lon))
|
121 |
-
true_lon, true_lat = self.coordinates[self.index]
|
122 |
-
pred_lon, pred_lat = self.preds[self.index]
|
123 |
-
|
124 |
-
self.ax.plot(pred_lon, pred_lat, 'gv', transform=ccrs.Geodetic())
|
125 |
-
self.ax.plot([true_lon, pred_lon], [true_lat, pred_lat], color='green', linewidth=1, transform=ccrs.Geodetic())
|
126 |
-
self.ax.plot(click_lon, click_lat, 'bo', transform=ccrs.Geodetic())
|
127 |
-
self.ax.plot([true_lon, click_lon], [true_lat, click_lat], color='blue', linewidth=1, transform=ccrs.Geodetic())
|
128 |
-
self.ax.plot(true_lon, true_lat, 'rx', transform=ccrs.Geodetic())
|
129 |
-
|
130 |
-
distance = haversine(true_lat, true_lon, click_lat, click_lon)
|
131 |
-
score = geoscore(distance)
|
132 |
-
self.stats['scores'].append(score)
|
133 |
-
self.stats['distances'].append(distance)
|
134 |
-
|
135 |
-
average_text = self.update_average_display()
|
136 |
-
result_text = (f"### GeoScore: {score:.0f}, distance: {distance:.0f} km\n ")
|
137 |
-
|
138 |
-
self.cache(self.index+1, score, distance, (click_lat, click_lon), time_elapsed)
|
139 |
-
return self.get_figure(), result_text + average_text
|
140 |
-
|
141 |
-
def next_image(self):
|
142 |
-
# Go to the next image
|
143 |
-
self.index += 1
|
144 |
-
return self.load_image()
|
145 |
-
|
146 |
-
def update_average_display(self):
|
147 |
-
# Calculate the average values
|
148 |
-
avg_score = sum(self.stats['scores']) / len(self.stats['scores']) if self.stats['scores'] else 0
|
149 |
-
avg_distance = sum(self.stats['distances']) / len(self.stats['distances']) if self.stats['distances'] else 0
|
150 |
-
|
151 |
-
# Update the text box
|
152 |
-
return f"### Average GeoScore: {avg_score:.0f}, Average distance: {avg_distance:.0f} km"
|
153 |
-
|
154 |
-
def finish(self):
|
155 |
-
clicks = rg.search(self.stats['clicked_locations'])
|
156 |
-
clicked_admins = [[click['name'], click['admin2'], click['admin1'], click['cc']] for click in clicks]
|
157 |
-
|
158 |
-
correct = [0,0,0,0]
|
159 |
-
valid = [0,0,0,0]
|
160 |
-
|
161 |
-
for clicked_admin, true_admin in zip(clicked_admins, self.admins):
|
162 |
-
for i in range(4):
|
163 |
-
if true_admin[i]!= 'nan':
|
164 |
-
valid[i] += 1
|
165 |
-
if true_admin[i] == clicked_admin[i]:
|
166 |
-
correct[i] += 1
|
167 |
-
|
168 |
-
avg_city_accuracy = correct[0] / valid[0]
|
169 |
-
avg_area_accuracy = correct[1] / valid[1]
|
170 |
-
avg_region_accuracy = correct[2] / valid[2]
|
171 |
-
avg_country_accuracy = correct[3] / valid[3]
|
172 |
-
|
173 |
-
avg_score = sum(self.stats['scores']) / len(self.stats['scores']) if self.stats['scores'] else 0
|
174 |
-
avg_distance = sum(self.stats['distances']) / len(self.stats['distances']) if self.stats['distances'] else 0
|
175 |
-
|
176 |
-
final_results = (
|
177 |
-
f"Average GeoScore: {avg_score:.0f} \n" +
|
178 |
-
f"Average distance: {avg_distance:.0f} km \n" +
|
179 |
-
f"Country Acc: {100*avg_country_accuracy:.1f} \n" +
|
180 |
-
f"Region Acc: {100*avg_region_accuracy:.1f} \n" +
|
181 |
-
f"Area Acc: {100*avg_area_accuracy:.1f} \n" +
|
182 |
-
f"City Acc: {100*avg_city_accuracy:.1f}"
|
183 |
-
)
|
184 |
-
|
185 |
-
self.cache_final(final_results)
|
186 |
-
|
187 |
-
# Update the text box
|
188 |
-
return f"# Your stats ๐\n" + final_results + f" \n# Thanks for playing โค๏ธ"
|
189 |
-
|
190 |
-
# Function to save the game state
|
191 |
-
def cache(self, index, score, distance, location, time_elapsed):
|
192 |
-
if not os.path.exists(self.cache_path):
|
193 |
-
os.makedirs(self.cache_path)
|
194 |
-
|
195 |
-
with open(join(self.cache_path, str(index).zfill(2) + '.txt'), 'w') as f:
|
196 |
-
print(f"{score}, {distance}, {location[0]}, {location[1]}, {time_elapsed}", file=f)
|
197 |
-
|
198 |
-
# Function to save the game state
|
199 |
-
def cache_final(self, final_results):
|
200 |
-
times = ', '.join(map(str, self.stats['times']))
|
201 |
-
fname = join(self.cache_path, 'full.txt')
|
202 |
-
with open(fname, 'w') as f:
|
203 |
-
print(f"{final_results}" + '\n Times: ' + times, file=f)
|
204 |
-
|
205 |
-
zip_ = self.cache_path.rstrip('/') + '.zip'
|
206 |
-
archived = shutil.make_archive(self.cache_path.rstrip('/'), 'zip', self.cache_path)
|
207 |
-
try:
|
208 |
-
wandb.init(project="plonk")
|
209 |
-
artifact = wandb.Artifact('results', type='results')
|
210 |
-
artifact.add_file(zip_)
|
211 |
-
wandb.log_artifact(artifact)
|
212 |
-
wandb.finish()
|
213 |
-
except Exception:
|
214 |
-
print("Failed to log results to wandb")
|
215 |
-
pass
|
216 |
-
|
217 |
-
if os.path.isfile(zip_):
|
218 |
-
os.remove(zip_)
|
219 |
-
|
220 |
-
|
221 |
-
if __name__ == "__main__":
|
222 |
-
# login with the key from secret
|
223 |
-
wandb.login()
|
224 |
-
csv_str = os.environ['csv']
|
225 |
-
with open(CSV_FILE, 'w') as f:
|
226 |
-
f.write(csv_str)
|
227 |
-
|
228 |
-
import gradio as gr
|
229 |
-
def click(state, evt: gr.SelectData):
|
230 |
-
if state['clicked']:
|
231 |
-
return gr.update(), gr.update()
|
232 |
-
x, y = evt.index
|
233 |
-
state['clicked'] = True
|
234 |
-
image, text = state['engine'].click(x, y)
|
235 |
-
return gr.update(value=image), gr.update(value=text)
|
236 |
-
|
237 |
-
def next_(state):
|
238 |
-
if state['clicked']:
|
239 |
-
if state['engine'].isfinal():
|
240 |
-
text = state['engine'].finish()
|
241 |
-
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=text), gr.update(visible=False)
|
242 |
-
else:
|
243 |
-
fig, image, text = state['engine'].next_image()
|
244 |
-
state['clicked'] = False
|
245 |
-
return gr.update(value=fig), gr.update(value=image), gr.update(value=text), gr.update(), gr.update()
|
246 |
-
else:
|
247 |
-
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
248 |
-
|
249 |
-
def start(state):
|
250 |
-
# create a unique random temporary name under CACHE_DIR
|
251 |
-
# generate random hex and make sure it doesn't exist under CACHE_DIR
|
252 |
-
while True:
|
253 |
-
path = str(uuid.uuid4().hex)
|
254 |
-
name = os.path.join(RESULTS_DIR, path)
|
255 |
-
if not os.path.exists(name):
|
256 |
-
break
|
257 |
-
|
258 |
-
state['engine'] = Engine(IMAGE_FOLDER, CSV_FILE, name)
|
259 |
-
state['clicked'] = False
|
260 |
-
fig, image, text = state['engine'].load_image()
|
261 |
-
|
262 |
-
return (
|
263 |
-
gr.update(value=fig, visible=True),
|
264 |
-
gr.update(value=image, visible=True),
|
265 |
-
gr.update(value=text, visible=True),
|
266 |
-
gr.update(visible=True),
|
267 |
-
gr.update(visible=True),
|
268 |
-
gr.update(visible=False),
|
269 |
-
gr.update(visible=False),
|
270 |
-
gr.update(visible=False),
|
271 |
-
gr.update(visible=False),
|
272 |
-
)
|
273 |
-
|
274 |
-
with gr.Blocks() as demo:
|
275 |
-
state = gr.State({})
|
276 |
-
rules = gr.Markdown(RULES, visible=True)
|
277 |
-
|
278 |
-
start_button = gr.Button("Start", visible=True)
|
279 |
-
with gr.Row():
|
280 |
-
map_ = gr.Image(label='Map', visible=False)
|
281 |
-
image_ = gr.Image(label='Image', visible=False)
|
282 |
-
with gr.Row():
|
283 |
-
text = gr.Markdown("", visible=False)
|
284 |
-
text_count = gr.Markdown("", visible=False)
|
285 |
-
|
286 |
-
next_button = gr.Button("Next", visible=False)
|
287 |
-
start_button.click(start, inputs=[state], outputs=[map_, image_, text_count, text, next_button, rules, state, start_button])
|
288 |
-
map_.select(click, inputs=[state], outputs=[map_, text])
|
289 |
-
next_button.click(next_, inputs=[state], outputs=[map_, image_, text_count, text, next_button])
|
290 |
-
|
291 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|