Rahkakavee Baskaran commited on
Commit
4547220
·
1 Parent(s): 3592072
Files changed (1) hide show
  1. app.py +113 -13
app.py CHANGED
@@ -4,6 +4,9 @@ import json
4
  from itertools import islice
5
  from typing import Generator
6
  from plotly import express as px
 
 
 
7
 
8
 
9
  def chunks(data: dict, size=13) -> Generator:
@@ -55,12 +58,11 @@ def load_json(path: str) -> dict:
55
  # Load Data
56
  data = load_json("data.json")
57
  taxonomy = load_json("taxonomy_processed_v3.json")
 
58
 
59
  theme_counts = dict(Counter([el["THEMA"] for el in data]))
60
  labels_counts = dict(Counter([el["BEZEICHNUNG"] for el in data]))
61
 
62
- taxonomy = taxonomy
63
-
64
  names = [""]
65
  parents = ["Musterdatenkatalog"]
66
 
@@ -79,17 +81,6 @@ parents, name, values = get_tree_map_data(
79
  root="Musterdatenkatalog",
80
  )
81
 
82
-
83
- # fig = go.Figure(
84
- # go.Treemap(
85
- # labels=name,
86
- # parents=parents,
87
- # root_color="white",
88
- # values=values,
89
- # # textinfo="label+value",
90
- # ),
91
- # )
92
-
93
  fig = px.treemap(
94
  names=name,
95
  parents=parents,
@@ -103,6 +94,115 @@ fig.update_layout(
103
  )
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  st.title("Musterdatenkatalog")
107
 
 
 
 
 
 
 
 
108
  st.plotly_chart(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from itertools import islice
5
  from typing import Generator
6
  from plotly import express as px
7
+ from safetensors import safe_open
8
+ from semantic_search import predict
9
+ from sentence_transformers import SentenceTransformer
10
 
11
 
12
  def chunks(data: dict, size=13) -> Generator:
 
58
  # Load Data
59
  data = load_json("data.json")
60
  taxonomy = load_json("taxonomy_processed_v3.json")
61
+ taxonomy_labels = [el["group"] + " - " + el["label"] for el in taxonomy]
62
 
63
  theme_counts = dict(Counter([el["THEMA"] for el in data]))
64
  labels_counts = dict(Counter([el["BEZEICHNUNG"] for el in data]))
65
 
 
 
66
  names = [""]
67
  parents = ["Musterdatenkatalog"]
68
 
 
81
  root="Musterdatenkatalog",
82
  )
83
 
 
 
 
 
 
 
 
 
 
 
 
84
  fig = px.treemap(
85
  names=name,
86
  parents=parents,
 
94
  )
95
 
96
 
97
+ tensors = {}
98
+ with safe_open("corpus_embeddings.pt", framework="pt", device="cpu") as f:
99
+ for k in f.keys():
100
+ tensors[k] = f.get_tensor(k)
101
+
102
+ model = SentenceTransformer(
103
+ model_name_or_path="and-effect/musterdatenkatalog_clf",
104
+ device="cpu",
105
+ use_auth_token=True,
106
+ )
107
+
108
+
109
+ st.set_page_config(layout="wide")
110
+
111
  st.title("Musterdatenkatalog")
112
 
113
+ col1, col2, col3 = st.columns(3)
114
+ col1.metric("Kommunale Datensätze", len(data))
115
+ col2.metric("Themen", len(theme_counts))
116
+ col3.metric("Bezeichnungen", len(labels_counts))
117
+
118
+ st.title("Taxonomy")
119
+
120
  st.plotly_chart(fig)
121
+
122
+ st.title("Predict a Dataset")
123
+
124
+ # create two columns and make left column wider
125
+
126
+ # st.markdown(
127
+ # """
128
+ # <style>
129
+ # div[data-testid="stVerticalBlock"] div[style*="flex-direction: column;"] div[data-testid="stVerticalBlock"] {
130
+ # border-radius: 15px;
131
+ # background-color: white;
132
+ # box-shadow: 0 0 10px #eee;
133
+ # border: 1px solid #ddd;
134
+ # padding: 1rem;;
135
+ # }
136
+ # </style>
137
+ # """,
138
+ # unsafe_allow_html=True,
139
+ # )
140
+
141
+ st.markdown(
142
+ """
143
+ <style>
144
+ /* Style columns */
145
+ [data-testid="column"] {
146
+ border-radius: 15px;
147
+ background-color: white;
148
+ box-shadow: 0 0 10px #eee;
149
+ border: 1px solid #ddd;
150
+ padding: 1rem;;
151
+ }
152
+
153
+ /* Style containers */
154
+ [data-testid="stVerticalBlock"] > [style*="flex-direction: column;"] > [data-testid="stVerticalBlock"] {
155
+ border-radius: 15px;
156
+ background-color: white;
157
+ box-shadow: 0 0 10px #eee;
158
+ border: 1px solid #ddd;
159
+ padding: 1rem;;
160
+ }
161
+ </style>
162
+ """,
163
+ unsafe_allow_html=True,
164
+ )
165
+
166
+
167
+ col1, col2 = st.columns([1.2, 1])
168
+
169
+
170
+ with col2:
171
+ st.subheader("Example Datasets")
172
+ examples = [
173
+ "Spielplätze",
174
+ "Berliner Weihnachtsmärkte 2022",
175
+ "Hochschulwechslerquoten zum Masterstudium nach Bundesländern",
176
+ "Umringe der Bebauungspläne von Etgert",
177
+ ]
178
+
179
+ for example in examples:
180
+ if st.button(example):
181
+ if "key" not in st.session_state:
182
+ st.session_state["query"] = example
183
+
184
+
185
+ with col1:
186
+ if "query" not in st.session_state:
187
+ query = st.text_input(
188
+ "Enter dataset name",
189
+ )
190
+ if "query" in st.session_state and st.session_state.query in examples:
191
+ query = st.text_input("Enter dataset name", value=st.session_state.query)
192
+ if "query" in st.session_state and st.session_state.query not in examples:
193
+ del st.session_state["query"]
194
+ query = st.text_input("Enter dataset name")
195
+
196
+ top_k = st.select_slider("Top Results", options=[1, 2, 3, 4, 5], value=1)
197
+
198
+ predictions = predict(
199
+ query=query,
200
+ corpus_embeddings=tensors["corpus_embeddings"],
201
+ corpus_labels=taxonomy_labels,
202
+ top_k=top_k,
203
+ model=model,
204
+ )
205
+
206
+ if st.button("Predict"):
207
+ for prediction in predictions:
208
+ st.write(prediction)