robocan committed
Commit cc6f22c · verified · 1 parent: 24bb85c

Update app.py

Files changed (1): app.py (+45, -48)
app.py CHANGED
@@ -8,17 +8,19 @@ import gradio as gr
 import plotly.graph_objects as go
 from io import BytesIO
 from PIL import Image
-from torchvision import transforms,models
-from sklearn.preprocessing import LabelEncoder,MinMaxScaler
+from torchvision import transforms, models
+from sklearn.preprocessing import LabelEncoder, MinMaxScaler
 from gradio import Interface, Image, Label, HTML
 from huggingface_hub import snapshot_download
+import s2sphere
+import folium
 
 # Retrieve the token from the environment variables
 token = os.environ.get("token")
 
 # Download the repository snapshot
 local_dir = snapshot_download(
-    repo_id="robocan/GeoG_coordinate",
+    repo_id="robocan/GeoG-GCP",
     repo_type="model",
     local_dir="SVD",
     token=token
@@ -27,7 +29,6 @@ local_dir = snapshot_download(
 device = 'cpu'
 le = LabelEncoder()
 le = joblib.load("SVD/le.gz")
-MMS = joblib.load("SVD/MMS.gz")
 len_classes = len(le.classes_) + 1
 
 class ModelPre(torch.nn.Module):
@@ -36,9 +37,9 @@ class ModelPre(torch.nn.Module):
         self.embedding = torch.nn.Sequential(
             *list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
             torch.nn.Flatten(),
-            torch.nn.Linear(in_features=768,out_features=512),
+            torch.nn.Linear(in_features=768, out_features=512),
             torch.nn.ReLU(),
-            torch.nn.Linear(in_features=512,out_features=len_classes),
+            torch.nn.Linear(in_features=512, out_features=len_classes),
         )
         # Freeze all layers
 
@@ -47,30 +48,8 @@ class ModelPre(torch.nn.Module):
 
 # Load the pretrained model
 model = ModelPre()
-#for param in model.parameters():
-#    param.requires_grad = False
-class GeoGcord(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.embedding = torch.nn.Sequential(
-            *list(model.children())[0][:-1],
-            torch.nn.Linear(in_features=512,out_features=256),
-            torch.nn.ReLU(),
-            torch.nn.Linear(in_features=256,out_features=128),
-            torch.nn.ReLU(),
-            torch.nn.Linear(in_features=128,out_features=2),
-        )
-        # Freeze all layers
 
-    def forward(self, data):
-        return self.embedding(data)
-
-
-
-# Load the pre-trained model
-model = GeoGcord()
 model_w = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
-
 model.load_state_dict(model_w['model'])
 
 cmp = transforms.Compose([
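Note on the hunk above: with the GeoGcord regression head removed, SVD/GeoG.pth is now loaded straight into the classification model. A minimal smoke test for that shape change might look like the sketch below; the 224x224 input size and the assumption that ModelPre.forward simply applies self.embedding are not shown in this diff.

# Hypothetical smoke test, not part of the commit: the restored head should emit
# one logit per label-encoder class (+1), matching len_classes above.
import torch

dummy = torch.zeros(1, 3, 224, 224)    # input size assumed; the Resize transform is not visible here
with torch.inference_mode():
    out = model(dummy.to(device))       # assumes ModelPre.forward returns self.embedding(data)
assert out.shape == (1, len_classes)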
@@ -79,27 +58,45 @@ cmp = transforms.Compose([
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])
 
-# Predict function for the new regression model
 def predict(input_img):
     with torch.inference_mode():
         img = cmp(input_img).unsqueeze(0)
         res = model(img.to(device))
-        # Assuming res is a 2-layer regression output, and MMS.inverse_transform is needed
-        prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
-        return prediction
+        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
+        top_10_indices = np.argsort(probabilities)[-10:][::-1]
+        top_10_probabilities = probabilities[top_10_indices]
+        top_10_predictions = le.inverse_transform(top_10_indices)
+
+        results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
+        return results, top_10_predictions
+
+# Function to get S2 cell polygon
+def get_s2_cell_polygon(cell_id):
+    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
+    vertices = []
+    for i in range(4):
+        vertex = s2sphere.LatLng.from_point(cell.get_vertex(i))
+        vertices.append((vertex.lat().degrees, vertex.lng().degrees))
+    vertices.append(vertices[0])  # Close the polygon
+    return vertices
 
 # Function to generate Plotly map figure
-def create_map_figure(lat, lon):
-    fig = go.Figure(go.Scattermapbox(
-        lat=[lat],
-        lon=[lon],
-        mode='markers',
-        marker=go.scattermapbox.Marker(
-            size=14
-        ),
-        text=[f'Lat: {lat}, Lon: {lon}'],
-        hoverinfo='text'
-    ))
+def create_map_figure(predictions, cell_ids):
+    fig = go.Figure()
+
+    for cell_id in cell_ids:
+        cell_id = int(cell_id)
+        polygon = get_s2_cell_polygon(cell_id)
+        lats, lons = zip(*polygon)
+        fig.add_trace(go.Scattermapbox(
+            lat=lats,
+            lon=lons,
+            mode='lines',
+            fill='toself',
+            fillcolor='rgba(0, 0, 255, 0.2)',
+            line=dict(color='blue'),
+            name=f'Cell ID: {cell_id}'
+        ))
 
     fig.update_layout(
         mapbox_style="open-street-map",
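The new predict returns label strings that create_map_figure later casts with int(cell_id), so the labels are treated as S2 cell ids. The commit does not show which S2 level the model was trained on; purely as an illustration of the same vertex extraction that get_s2_cell_polygon performs, with level 8 chosen arbitrarily:

import s2sphere

# Illustration only: build a level-8 cell around a known point, then read back its
# four corners exactly the way get_s2_cell_polygon does.
latlng = s2sphere.LatLng.from_degrees(51.5074, -0.1278)     # central London
cell_id = s2sphere.CellId.from_lat_lng(latlng).parent(8)    # level 8 is an arbitrary choice here
cell = s2sphere.Cell(cell_id)

corners = [s2sphere.LatLng.from_point(cell.get_vertex(i)) for i in range(4)]
print([(c.lat().degrees, c.lng().degrees) for c in corners])
print(cell_id.id())   # the integer form that int(cell_id) in create_map_figure expects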
@@ -107,8 +104,8 @@ def create_map_figure(lat, lon):
         mapbox=dict(
             bearing=0,
             center=go.layout.mapbox.Center(
-                lat=lat,
-                lon=lon
+                lat=np.mean(lats),
+                lon=np.mean(lons)
             ),
             pitch=0,
             zoom=3
@@ -119,8 +116,8 @@ def create_map_figure(lat, lon):
 
 # Create label output function
 def create_label_output(predictions):
-    lat, lon = predictions
-    fig = create_map_figure(lat, lon)
+    results, cell_ids = predictions
+    fig = create_map_figure(results, cell_ids)
     return fig
 
 # Predict and plot function
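predict_and_plot itself sits outside the hunks shown, so its body is not part of this diff. Given that predict now returns a (results, cell_ids) pair and create_label_output unpacks exactly that, a wiring consistent with these changes would be roughly:

# Assumed wiring, not shown in the commit: only the function name appears in the diff.
def predict_and_plot(input_img):
    predictions = predict(input_img)           # (top-10 results dict, top-10 cell id labels)
    return create_label_output(predictions)    # builds the S2-cell map figure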
@@ -138,4 +135,4 @@ with gr.Blocks() as gradio_app:
     btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
     examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
     gr.Examples(examples=examples, inputs=input_image)
-    gradio_app.launch()
+gradio_app.launch()
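For a quick check outside Gradio, the new flow can be exercised directly with one of the example images bundled with the Space. This sketch is not part of the commit and assumes the SVD snapshot and the example PNGs are available locally:

# Local smoke test of the classification + S2-cell map path.
from PIL import Image as PILImage   # alias to avoid the gradio Image imported above

img = PILImage.open("GB.PNG").convert("RGB")
results, cell_ids = predict(img)
print(results)                                # {cell id label: probability} for the top 10
fig = create_map_figure(results, cell_ids)
fig.show()                                    # open-street-map view with the shaded cells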
 