import os
import torch
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder
from huggingface_hub import snapshot_download
import s2sphere
# Download the model repository from the Hugging Face Hub into ./SVD
token = os.environ.get("token")
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = 'cpu'

# Fitted LabelEncoder mapping class indices to S2 cell labels
le = joblib.load("SVD/le.gz")
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone (classifier head removed) followed by an MLP over the S2 cell classes."""
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            *list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=len_classes),
        )

    def forward(self, data):
        return self.embedding(data)
# Load the pretrained weights (the checkpoint stores the state dict under the 'model' key)
model = ModelPre().to(device)
model_w = torch.load("SVD/GeoG.pth", map_location=device)
model.load_state_dict(model_w['model'])
model.eval()
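# Optional sanity check for the loaded weights, kept commented out so it does not run at
# Space startup (illustrative sketch; the random tensor merely stands in for a real image):
#
#     with torch.inference_mode():
#         dummy = torch.randn(1, 3, 224, 224)
#         print(model(dummy).shape)  # expected: torch.Size([1, len_classes])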
# Preprocessing: resize to 224x224 and normalise with the ImageNet statistics the backbone expects
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(input_img):
    """Return the top-10 predicted S2 cells for a PIL image as ({label: probability}, labels)."""
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = model(img.to(device))
        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
        top_10_indices = np.argsort(probabilities)[-10:][::-1]
        top_10_probabilities = probabilities[top_10_indices]
        top_10_predictions = le.inverse_transform(top_10_indices)
        results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
        return results, top_10_predictions
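# predict() can also be used on its own outside the Gradio callback; a commented-out
# sketch (nothing runs at import, and "photo.jpg" is only a placeholder path):
#
#     from PIL import Image
#     results, cells = predict(Image.open("photo.jpg"))
#     best_cell = int(cells[0])  # highest-probability S2 cell id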
# Function to get S2 cell polygon
def get_s2_cell_polygon(cell_id):
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    vertices = []
    for i in range(4):
        vertex = s2sphere.LatLng.from_point(cell.get_vertex(i))
        vertices.append((vertex.lat().degrees, vertex.lng().degrees))
    vertices.append(vertices[0])  # Close the polygon
    return vertices
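# An S2 cell id encodes a quadrilateral patch on the sphere; the helper above returns its
# four corners plus the first corner again so Plotly can close the shape. Commented-out
# sketch that builds a valid cell id from a hypothetical lat/lng rather than hard-coding one:
#
#     cid = s2sphere.CellId.from_lat_lng(s2sphere.LatLng.from_degrees(51.5, -0.12)).parent(10)
#     corners = get_s2_cell_polygon(cid.id())
#     print(len(corners))  # 5 points: four vertices plus the closing point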
def create_map_figure(predictions, cell_ids):
    fig = go.Figure()

    # Assign colours by rank: the top 3 cells are green, the remaining 7 yellow
    colors = ['rgba(0, 255, 0, 0.2)'] * 3 + ['rgba(255, 255, 0, 0.2)'] * 7

    for rank, cell_id in enumerate(cell_ids):
        cell_id = int(cell_id)
        polygon = get_s2_cell_polygon(cell_id)
        lats, lons = zip(*polygon)
        color = colors[rank]

        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Cell ID: {cell_id}'
        ))

    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            # Centre the view on the last cell drawn
            center=go.layout.mapbox.Center(
                lat=np.mean(lats),
                lon=np.mean(lons)
            ),
            pitch=0,
            zoom=1
        ),
    )
    return fig
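# For debugging outside Gradio, the resulting Plotly figure can be written to a standalone
# HTML file (commented out; `some_pil_image` is a placeholder for any PIL image):
#
#     results, cells = predict(some_pil_image)
#     create_map_figure(results, cells).write_html("prediction_map.html")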
# Create label output function
def create_label_output(predictions):
    results, cell_ids = predictions
    fig = create_map_figure(results, cell_ids)
    return fig

# Predict and plot function
def predict_and_plot(input_img):
    predictions = predict(input_img)
    return create_label_output(predictions)
# Gradio app definition
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
        btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)

gradio_app.launch()