Josephina committed on
Commit
567e1bb
·
1 Parent(s): 162ce19

app changed

Browse files
Files changed (1) hide show
  1. app.py +68 -21
app.py CHANGED
@@ -10,7 +10,6 @@ from safetensors import safe_open
10
  from sentence_transformers import SentenceTransformer
11
 
12
  from semantic_search import predict
13
- from utils.process_data import add_coor, load_data, merge_geoemtry
14
 
15
  HF_TOKEN = os.environ.get("HF_TOKEN")
16
  CITIES_ENRICHED = os.path.join("data", "cities_enriched_manually.csv")
@@ -151,8 +150,9 @@ fig = go.Figure(
151
  fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
152
  fig.update_layout(height=1000, width=1000, template="plotly")
153
 
 
154
  # load data ready to plot for local testing
155
- germany = pd.read_csv(MAP_PATH)
156
  # germany.drop(columns=["lat", "lon"], inplace=True)
157
 
158
  # # or generate it directly in this script
@@ -162,15 +162,60 @@ germany = pd.read_csv(MAP_PATH)
162
  # germany.to_csv(MAP_PATH_WITH_COORD, index=False)
163
 
164
  # # germany need columns with lat and lon as well as hover data
165
- fig_map = px.scatter_mapbox(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  germany,
167
  lat="lat",
168
  lon="lon",
169
  hover_name="ORG",
170
  custom_data=["Count"],
171
- # color_discrete_map=["magenta"],
172
- zoom=5,
173
- height=700,
174
  )
175
  # Custom hover template
176
  fig_map.update_traces(
@@ -181,22 +226,24 @@ fig_map.update_traces(
181
  ]
182
  )
183
  )
184
- fig_map.update_layout(mapbox_style="carto-positron")
185
-
186
- tensors = {}
187
- with safe_open("corpus_embeddings.pt", framework="pt", device="cpu") as f:
188
- for k in f.keys():
189
- tensors[k] = f.get_tensor(k)
190
-
191
- model = SentenceTransformer(
192
- model_name_or_path="and-effect/musterdatenkatalog_clf",
193
- device="cpu",
194
- use_auth_token=HF_TOKEN,
 
 
 
 
 
195
  )
196
 
197
-
198
- st.set_page_config(layout="wide")
199
-
200
  st.title("Musterdatenkatalog (MDK)")
201
 
202
  st.markdown(
@@ -356,5 +403,5 @@ st.markdown(
356
  """,
357
  unsafe_allow_html=True,
358
  )
359
- st.table(germany.head())
360
  st.plotly_chart(fig_map)
 
10
  from sentence_transformers import SentenceTransformer
11
 
12
  from semantic_search import predict
 
13
 
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
  CITIES_ENRICHED = os.path.join("data", "cities_enriched_manually.csv")
 
150
  fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
151
  fig.update_layout(height=1000, width=1000, template="plotly")
152
 
153
+
154
  # load data ready to plot for local testing
155
+
156
  # germany.drop(columns=["lat", "lon"], inplace=True)
157
 
158
  # # or generate it directly in this script
 
162
  # germany.to_csv(MAP_PATH_WITH_COORD, index=False)
163
 
164
  # # germany need columns with lat and lon as well as hover data
165
+
166
+
167
+ tensors = {}
168
+ with safe_open("corpus_embeddings.pt", framework="pt", device="cpu") as f:
169
+ for k in f.keys():
170
+ tensors[k] = f.get_tensor(k)
171
+
172
+ model = SentenceTransformer(
173
+ model_name_or_path="and-effect/musterdatenkatalog_clf",
174
+ device="cpu",
175
+ use_auth_token=HF_TOKEN,
176
+ )
177
+
178
+
179
+ st.set_page_config(layout="wide")
180
+
181
+
182
+ @st.cache_data
183
+ def load_data() -> pd.DataFrame:
184
+ germany = pd.read_csv(MAP_PATH)
185
+ return germany
186
+
187
+
188
+ # germany = load_data()
189
+ # germany["lat"] = pd.to_numeric(germany["lat"])
190
+ # germany["lon"] = pd.to_numeric(germany["lon"])
191
+ germany = pd.DataFrame(
192
+ {
193
+ "ORG": [
194
+ "Berlin",
195
+ "Hamburg",
196
+ "München",
197
+ "Köln",
198
+ "Frankfurt am Main",
199
+ "Stuttgart",
200
+ "Düsseldorf",
201
+ "Leipzig",
202
+ "Dortmund",
203
+ "Essen",
204
+ ],
205
+ "lat": [52.52, 53.55, 48.14, 50.94, 50.11, 48.78, 51.22, 51.34, 51.51, 51.45],
206
+ "lon": [13.41, 9.99, 11.58, 6.96, 8.68, 9.18, 6.77, 12.37, 7.46, 7.01],
207
+ "Count": [1000, 800, 600, 500, 400, 300, 200, 150, 100, 50],
208
+ }
209
+ )
210
+
211
+
212
+ fig_map = px.scatter_geo(
213
  germany,
214
  lat="lat",
215
  lon="lon",
216
  hover_name="ORG",
217
  custom_data=["Count"],
218
+ scope="europe",
 
 
219
  )
220
  # Custom hover template
221
  fig_map.update_traces(
 
226
  ]
227
  )
228
  )
229
+ fig_map.update_layout(
230
+ geo=dict(
231
+ showland=True,
232
+ landcolor="LightGray",
233
+ showocean=True,
234
+ oceancolor="LightBlue",
235
+ # showcountries=True,
236
+ # countrycolor="Gray",
237
+ showsubunits=True,
238
+ # subunitcolor="Gray",
239
+ fitbounds="locations", # Fit the map bounds to the locations
240
+ lataxis=dict(range=[47, 55]), # Approximate latitude range for Germany
241
+ lonaxis=dict(range=[5, 16]), # Approximate longitude range for Germany
242
+ ),
243
+ mapbox_style="carto-positron",
244
+ height=700,
245
  )
246
 
 
 
 
247
  st.title("Musterdatenkatalog (MDK)")
248
 
249
  st.markdown(
 
403
  """,
404
  unsafe_allow_html=True,
405
  )
406
+
407
  st.plotly_chart(fig_map)