Spaces:
Running
Running
File size: 5,153 Bytes
9b889da afe2deb 8a4658f afe2deb 38a655b 4b52e6f afe2deb cc6f22c 4b52e6f dbf177b 4b67dd1 cc6f22c 9b889da dbf177b 1de1740 db673be dbf177b db673be 9b889da 955fc23 14085cb 4184b6d 955fc23 be60ccb 955fc23 1de1740 be60ccb 1de1740 955fc23 f07227d d497c1d f07227d 955fc23 6d49cf1 955fc23 f07227d cc6f22c 6d49cf1 a8f08c9 cc6f22c 05501c8 a8f08c9 05501c8 03ec0c3 cc6f22c 05501c8 a8f08c9 cc6f22c 05501c8 cc6f22c 3e8fcf7 cc6f22c 24d6812 a8f08c9 24d6812 a8f08c9 24d6812 a8f08c9 24d6812 a8f08c9 24d6812 4b52e6f 05501c8 f07227d afe2deb cc6f22c 24d6812 8a4658f a8f08c9 8a4658f a8f08c9 afe2deb 05501c8 f07227d 24d6812 a8f08c9 24d6812 a8f08c9 dd76dcd a8f08c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
import torch_xla.utils.serialization as xser
import s2sphere
import folium
# Fetch the model artefacts from the Hub into ./SVD. The access token comes
# from the "token" environment variable (None when it is not set).
token = os.environ.get("token")
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
    token=token,
)
device = 'cpu'
# Fitted LabelEncoder that maps model output indices back to the labels used
# downstream as S2 cell ids. Read from the directory snapshot_download
# reported, so the path has a single source of truth.
le = joblib.load(os.path.join(local_dir, "le.gz"))
# The classifier head is built with one more output than the encoder has
# labels. NOTE(review): presumably this mirrors the training configuration
# (e.g. a reserved index) — confirm. The extra index has no label and must
# not be passed to le.inverse_transform.
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """ConvNeXt-Small trunk followed by a three-layer MLP classification head.

    Every layer lives in one ``torch.nn.Sequential`` called ``embedding``, so
    state-dict keys are prefixed ``embedding.<index>.``. The layer order below
    must therefore stay fixed for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        # ImageNet-pretrained ConvNeXt-Small. Every child except the last
        # (torchvision's own classifier) is kept as the feature trunk.
        convnext = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1)
        trunk = list(convnext.children())[:-1]
        # Flatten the 768-dim pooled features, then map them to class logits.
        head = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=len_classes),
        ]
        self.embedding = torch.nn.Sequential(*trunk, *head)

    def forward(self, data):
        """Return raw (pre-softmax) logits with ``len_classes`` columns for a batched image tensor."""
        return self.embedding(data)
# Build the network and load the fine-tuned weights. The checkpoint is read
# with torch_xla's serializer (xser.load) — presumably it was saved with
# xser.save on an XLA device; confirm before switching to torch.load.
model = ModelPre()
model_w = xser.load(os.path.join(local_dir, "GeoG.pth"))
model.load_state_dict(model_w['model'])
model.to(device)
# eval() is required for deterministic inference. torchvision's ConvNeXt
# contains StochasticDepth layers that randomly drop residual branches in
# training mode, and torch.inference_mode() alone does not turn them off.
model.eval()

# Preprocessing: convert to a float tensor, resize to 224x224, and normalise
# with the ImageNet channel mean/std.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(input_img, top_k=10):
    """Classify ``input_img`` and return its ``top_k`` most likely labels.

    Args:
        input_img: Image accepted by ``cmp``; the Gradio component passes a
            PIL image. Images exposing ``convert`` (PIL) are converted to RGB
            first, so RGBA PNGs and greyscale uploads yield the three channels
            ``Normalize`` and the ConvNeXt stem expect.
        top_k: Number of predictions to return (default 10).

    Returns:
        ``(results, top_predictions)``. ``top_predictions`` holds the decoded
        labels in descending probability order. ``results`` maps each of
        those labels to its softmax probability as a Python float.

    Raises:
        ValueError: if ``input_img`` is None (nothing was uploaded).
    """
    if input_img is None:
        raise ValueError("predict() requires an input image")
    if hasattr(input_img, "convert"):
        input_img = input_img.convert("RGB")
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        logits = model(img.to(device))
        probabilities = torch.softmax(logits, dim=1).cpu().numpy().flatten()
    # The head has len(le.classes_) + 1 outputs (see len_classes). The extra
    # index has no encoder label and le.inverse_transform would raise on it,
    # so only indices the encoder knows are ranked.
    n_labels = len(le.classes_)
    order = np.argsort(probabilities)[::-1]
    top_indices = order[order < n_labels][:top_k]
    top_probabilities = probabilities[top_indices]
    top_predictions = le.inverse_transform(top_indices)
    results = {label: float(p) for label, p in zip(top_predictions, top_probabilities)}
    return results, top_predictions
def get_s2_cell_polygon(cell_id):
    """Return the closed outline of an S2 cell as ``(lat, lng)`` degree pairs.

    Args:
        cell_id: Integer S2 cell id.

    Returns:
        A list of five ``(lat, lng)`` tuples: the cell's four corner vertices
        in ``get_vertex`` order 0-3, followed by the first corner again so the
        ring is closed for polygon plotting.
    """
    s2_cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    corners = [s2sphere.LatLng.from_point(s2_cell.get_vertex(k)) for k in range(4)]
    ring = [(corner.lat().degrees, corner.lng().degrees) for corner in corners]
    return ring + [ring[0]]
def _parse_selected_rank(selected_index):
    """Return a 0-based rank from an int or a ``"Prediction N"`` label, else None."""
    if selected_index is None:
        return None
    if isinstance(selected_index, str):
        prefix = "Prediction "
        if not selected_index.startswith(prefix):
            return None
        try:
            # Dropdown labels are 1-based ("Prediction 1" is the top rank).
            return int(selected_index[len(prefix):]) - 1
        except ValueError:
            return None
    try:
        return int(selected_index)
    except (TypeError, ValueError):
        return None


def _coerce_cell_id(raw_id):
    """Convert a label to an integer S2 cell id, without a float round-trip when possible."""
    try:
        # Exact for ints and integer strings; a float round-trip can drop low
        # bits of large 64-bit ids.
        return int(raw_id)
    except (TypeError, ValueError):
        return int(float(raw_id))


def create_map_figure(predictions, cell_ids, selected_index=None):
    """Draw the predicted S2 cells as filled polygons on an OpenStreetMap figure.

    Args:
        predictions: Not used by this function. It is kept so existing callers
            that pass the ``results`` mapping (or the whole ``predict`` output)
            keep working.
        cell_ids: S2 cell ids ordered from most to least likely. Each may be an
            int, a float, or a numeric string.
        selected_index: Optional cell to zoom into. It may be a 0-based rank
            (int) or a dropdown label ``"Prediction N"`` (1-based N). ``None``,
            an unrecognised value, or an out-of-range rank leaves the map
            zoomed out.

    Returns:
        A ``plotly.graph_objects.Figure`` with one ``Scattermapbox`` trace per
        cell. Ranks 1-3 are filled green and lower ranks yellow. With a valid
        selection the map is centred on that cell at zoom 10. Otherwise it is
        centred on the top-ranked cell at zoom 1, or on (0, 0) when there are
        no cells.
    """
    selected_rank = _parse_selected_rank(selected_index)
    fig = go.Figure()
    centres = []
    for rank, raw_id in enumerate(cell_ids):
        polygon = get_s2_cell_polygon(_coerce_cell_id(raw_id))
        lats, lons = zip(*polygon)
        # Top-3 predictions are highlighted in green, the rest in yellow.
        color = 'rgba(0, 255, 0, 0.2)' if rank < 3 else 'rgba(255, 255, 0, 0.2)'
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Prediction {rank + 1}',
        ))
        # Average the four corners only; the last vertex repeats the first to
        # close the ring and would otherwise be double-counted.
        centres.append((float(np.mean(lats[:-1])), float(np.mean(lons[:-1]))))

    if selected_rank is not None and 0 <= selected_rank < len(centres):
        (center_lat, center_lon), zoom_level = centres[selected_rank], 10
    elif centres:
        (center_lat, center_lon), zoom_level = centres[0], 1
    else:
        center_lat, center_lon, zoom_level = 0.0, 0.0, 1

    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon),
            pitch=0,
            zoom=zoom_level,
        ),
    )
    return fig
def create_label_output(predictions):
    """Turn a ``(results, cell_ids)`` pair into a map figure.

    ``predictions`` is expected in the shape returned by ``predict``. It is
    unpacked into exactly two items and forwarded to ``create_map_figure``
    with no zoom selection.
    """
    label_scores, ranked_cells = predictions
    return create_map_figure(label_scores, ranked_cells)
def predict_and_plot(input_img, selected_prediction):
    """Run the model on ``input_img`` and return a map of its top predictions.

    Args:
        input_img: Image from the Gradio ``Image`` component, or ``None`` when
            nothing was uploaded.
        selected_prediction: Dropdown value such as ``"Prediction 3"``
            (1-based), an int 0-based rank, or ``None`` for no zoom.

    Returns:
        The Plotly figure built by ``create_map_figure``.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    results, cell_ids = predict(input_img)
    # The dropdown yields labels like "Prediction 3" (1-based), while
    # create_map_figure matches against 0-based enumerate ranks, so convert.
    selected_index = None
    if isinstance(selected_prediction, str):
        _, _, number = selected_prediction.rpartition(" ")
        if number.isdigit():
            selected_index = int(number) - 1
    elif selected_prediction is not None:
        selected_index = int(selected_prediction)
    return create_map_figure(results, cell_ids, selected_index=selected_index)
# Gradio UI: an image upload, a dropdown naming which ranked prediction to
# zoom into, a map output, and a "Predict" button wiring them together.
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        zoom_choices = [f"Prediction {n}" for n in range(1, 11)]
        selected_prediction = gr.Dropdown(
            choices=zoom_choices,
            label="Select Prediction to Zoom",
        )
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
        # Clicking passes the uploaded image and the dropdown's current value
        # to predict_and_plot and shows the returned figure in the map.
        btn_predict.click(
            predict_and_plot,
            inputs=[input_image, selected_prediction],
            outputs=output_map,
        )
        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)

gradio_app.launch()