Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,17 +8,19 @@ import gradio as gr
|
|
8 |
import plotly.graph_objects as go
|
9 |
from io import BytesIO
|
10 |
from PIL import Image
|
11 |
-
from torchvision import transforms,models
|
12 |
-
from sklearn.preprocessing import LabelEncoder,MinMaxScaler
|
13 |
from gradio import Interface, Image, Label, HTML
|
14 |
from huggingface_hub import snapshot_download
|
|
|
|
|
15 |
|
16 |
# Retrieve the token from the environment variables
|
17 |
token = os.environ.get("token")
|
18 |
|
19 |
# Download the repository snapshot
|
20 |
local_dir = snapshot_download(
|
21 |
-
repo_id="robocan/
|
22 |
repo_type="model",
|
23 |
local_dir="SVD",
|
24 |
token=token
|
@@ -27,7 +29,6 @@ local_dir = snapshot_download(
|
|
27 |
device = 'cpu'
|
28 |
le = LabelEncoder()
|
29 |
le = joblib.load("SVD/le.gz")
|
30 |
-
MMS = joblib.load("SVD/MMS.gz")
|
31 |
len_classes = len(le.classes_) + 1
|
32 |
|
33 |
class ModelPre(torch.nn.Module):
|
@@ -36,9 +37,9 @@ class ModelPre(torch.nn.Module):
|
|
36 |
self.embedding = torch.nn.Sequential(
|
37 |
*list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
|
38 |
torch.nn.Flatten(),
|
39 |
-
torch.nn.Linear(in_features=768,out_features=512),
|
40 |
torch.nn.ReLU(),
|
41 |
-
torch.nn.Linear(in_features=512,out_features=len_classes),
|
42 |
)
|
43 |
# Freeze all layers
|
44 |
|
@@ -47,30 +48,8 @@ class ModelPre(torch.nn.Module):
|
|
47 |
|
48 |
# Load the pretrained model
|
49 |
model = ModelPre()
|
50 |
-
#for param in model.parameters():
|
51 |
-
# param.requires_grad = False
|
52 |
-
class GeoGcord(torch.nn.Module):
|
53 |
-
def __init__(self):
|
54 |
-
super().__init__()
|
55 |
-
self.embedding = torch.nn.Sequential(
|
56 |
-
*list(model.children())[0][:-1],
|
57 |
-
torch.nn.Linear(in_features=512,out_features=256),
|
58 |
-
torch.nn.ReLU(),
|
59 |
-
torch.nn.Linear(in_features=256,out_features=128),
|
60 |
-
torch.nn.ReLU(),
|
61 |
-
torch.nn.Linear(in_features=128,out_features=2),
|
62 |
-
)
|
63 |
-
# Freeze all layers
|
64 |
|
65 |
-
def forward(self, data):
|
66 |
-
return self.embedding(data)
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
# Load the pre-trained model
|
71 |
-
model = GeoGcord()
|
72 |
model_w = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
|
73 |
-
|
74 |
model.load_state_dict(model_w['model'])
|
75 |
|
76 |
cmp = transforms.Compose([
|
@@ -79,27 +58,45 @@ cmp = transforms.Compose([
|
|
79 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
80 |
])
|
81 |
|
82 |
-
# Predict function for the new regression model
|
83 |
def predict(input_img):
|
84 |
with torch.inference_mode():
|
85 |
img = cmp(input_img).unsqueeze(0)
|
86 |
res = model(img.to(device))
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
# Function to generate Plotly map figure
|
92 |
-
def create_map_figure(
|
93 |
-
fig = go.Figure(
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
fig.update_layout(
|
105 |
mapbox_style="open-street-map",
|
@@ -107,8 +104,8 @@ def create_map_figure(lat, lon):
|
|
107 |
mapbox=dict(
|
108 |
bearing=0,
|
109 |
center=go.layout.mapbox.Center(
|
110 |
-
lat=
|
111 |
-
lon=
|
112 |
),
|
113 |
pitch=0,
|
114 |
zoom=3
|
@@ -119,8 +116,8 @@ def create_map_figure(lat, lon):
|
|
119 |
|
120 |
# Create label output function
|
121 |
def create_label_output(predictions):
|
122 |
-
|
123 |
-
fig = create_map_figure(
|
124 |
return fig
|
125 |
|
126 |
# Predict and plot function
|
@@ -138,4 +135,4 @@ with gr.Blocks() as gradio_app:
|
|
138 |
btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
|
139 |
examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
|
140 |
gr.Examples(examples=examples, inputs=input_image)
|
141 |
-
gradio_app.launch()
|
|
|
8 |
import plotly.graph_objects as go
|
9 |
from io import BytesIO
|
10 |
from PIL import Image
|
11 |
+
from torchvision import transforms, models
|
12 |
+
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
|
13 |
from gradio import Interface, Image, Label, HTML
|
14 |
from huggingface_hub import snapshot_download
|
15 |
+
import s2sphere
|
16 |
+
import folium
|
17 |
|
18 |
# Retrieve the token from the environment variables
|
19 |
token = os.environ.get("token")
|
20 |
|
21 |
# Download the repository snapshot
|
22 |
local_dir = snapshot_download(
|
23 |
+
repo_id="robocan/GeoG-GCP",
|
24 |
repo_type="model",
|
25 |
local_dir="SVD",
|
26 |
token=token
|
|
|
29 |
device = 'cpu'
|
30 |
le = LabelEncoder()
|
31 |
le = joblib.load("SVD/le.gz")
|
|
|
32 |
len_classes = len(le.classes_) + 1
|
33 |
|
34 |
class ModelPre(torch.nn.Module):
|
|
|
37 |
self.embedding = torch.nn.Sequential(
|
38 |
*list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
|
39 |
torch.nn.Flatten(),
|
40 |
+
torch.nn.Linear(in_features=768, out_features=512),
|
41 |
torch.nn.ReLU(),
|
42 |
+
torch.nn.Linear(in_features=512, out_features=len_classes),
|
43 |
)
|
44 |
# Freeze all layers
|
45 |
|
|
|
48 |
|
49 |
# Load the pretrained model
|
50 |
model = ModelPre()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
model_w = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
|
|
|
53 |
model.load_state_dict(model_w['model'])
|
54 |
|
55 |
cmp = transforms.Compose([
|
|
|
58 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
59 |
])
|
60 |
|
|
|
61 |
def predict(input_img):
|
62 |
with torch.inference_mode():
|
63 |
img = cmp(input_img).unsqueeze(0)
|
64 |
res = model(img.to(device))
|
65 |
+
probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
|
66 |
+
top_10_indices = np.argsort(probabilities)[-10:][::-1]
|
67 |
+
top_10_probabilities = probabilities[top_10_indices]
|
68 |
+
top_10_predictions = le.inverse_transform(top_10_indices)
|
69 |
+
|
70 |
+
results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
|
71 |
+
return results, top_10_predictions
|
72 |
+
|
73 |
+
# Function to get S2 cell polygon
|
74 |
+
def get_s2_cell_polygon(cell_id):
|
75 |
+
cell = s2sphere.Cell(s2sphere.CellId(cell_id))
|
76 |
+
vertices = []
|
77 |
+
for i in range(4):
|
78 |
+
vertex = s2sphere.LatLng.from_point(cell.get_vertex(i))
|
79 |
+
vertices.append((vertex.lat().degrees, vertex.lng().degrees))
|
80 |
+
vertices.append(vertices[0]) # Close the polygon
|
81 |
+
return vertices
|
82 |
|
83 |
# Function to generate Plotly map figure
|
84 |
+
def create_map_figure(predictions, cell_ids):
|
85 |
+
fig = go.Figure()
|
86 |
+
|
87 |
+
for cell_id in cell_ids:
|
88 |
+
cell_id = int(cell_id)
|
89 |
+
polygon = get_s2_cell_polygon(cell_id)
|
90 |
+
lats, lons = zip(*polygon)
|
91 |
+
fig.add_trace(go.Scattermapbox(
|
92 |
+
lat=lats,
|
93 |
+
lon=lons,
|
94 |
+
mode='lines',
|
95 |
+
fill='toself',
|
96 |
+
fillcolor='rgba(0, 0, 255, 0.2)',
|
97 |
+
line=dict(color='blue'),
|
98 |
+
name=f'Cell ID: {cell_id}'
|
99 |
+
))
|
100 |
|
101 |
fig.update_layout(
|
102 |
mapbox_style="open-street-map",
|
|
|
104 |
mapbox=dict(
|
105 |
bearing=0,
|
106 |
center=go.layout.mapbox.Center(
|
107 |
+
lat=np.mean(lats),
|
108 |
+
lon=np.mean(lons)
|
109 |
),
|
110 |
pitch=0,
|
111 |
zoom=3
|
|
|
116 |
|
117 |
# Create label output function
|
118 |
def create_label_output(predictions):
|
119 |
+
results, cell_ids = predictions
|
120 |
+
fig = create_map_figure(results, cell_ids)
|
121 |
return fig
|
122 |
|
123 |
# Predict and plot function
|
|
|
135 |
btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
|
136 |
examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
|
137 |
gr.Examples(examples=examples, inputs=input_image)
|
138 |
+
gradio_app.launch()
|