bdpedigo committed on
Commit
48fe07c
·
verified ·
1 Parent(s): 24fe77d

Upload folder using huggingface_hub

Files changed (4)
  1. README.md +396 -0
  2. config.json +360 -0
  3. local_compartment_classifier_bd_boxes.skops +0 -0
  4. train.py +359 -0
README.md ADDED
@@ -0,0 +1,396 @@
+ ---
+ license: mit
+ library_name: sklearn
+ tags:
+ - sklearn
+ - skops
+ - tabular-classification
+ model_format: skops
+ model_file: local_compartment_classifier_bd_boxes.skops
+ widget:
+ - structuredData:
+     area_nm2:
+     - 693824.0
+     - 4852608.0
+     - 17088896.0
+     area_nm2_neighbor_mean:
+     - 10181485.714285716
+     - 9884429.714285716
+     - 9010409.142857144
+     area_nm2_neighbor_std:
+     - 8312409.263207569
+     - 8587259.418816902
+     - 8418630.640116522
+     max_dt_nm:
+     - 69.0
+     - 543.0
+     - 1287.0
+     max_dt_nm_neighbor_mean:
+     - 664.7142857142857
+     - 630.8571428571429
+     - 577.7142857142857
+     max_dt_nm_neighbor_std:
+     - 479.64240342658945
+     - 504.9563358340017
+     - 468.41868657651344
+     mean_dt_nm:
+     - 24.4375
+     - 156.5
+     - 416.0
+     mean_dt_nm_neighbor_mean:
+     - 198.62946428571428
+     - 189.19642857142856
+     - 170.66071428571428
+     mean_dt_nm_neighbor_std:
+     - 150.614304054458
+     - 157.4368957825056
+     - 143.32375093543624
+     pca_ratio_01:
+     - 1.3849340770961909
+     - 1.181656878273399
+     - 1.128046800200765
+     pca_ratio_01_neighbor_mean:
+     - 1.8575624906424115
+     - 1.8760422359899387
+     - 1.880915879451087
+     pca_ratio_01_neighbor_std:
+     - 0.641580757345606
+     - 0.6228187048854344
+     - 0.6165585104590592
+     pca_unwrapped_0:
+     - -0.0046539306640625
+     - -0.497314453125
+     - -0.258544921875
+     pca_unwrapped_0_neighbor_mean:
+     - 0.039224624633789
+     - 0.0840119448575106
+     - 0.0623056238347833
+     pca_unwrapped_0_neighbor_std:
+     - 0.3114910605258688
+     - 0.2573427692683507
+     - 0.296254177168357
+     pca_unwrapped_1:
+     - 0.7392578125
+     - -0.11553955078125
+     - 0.2169189453125
+     pca_unwrapped_1_neighbor_mean:
+     - 0.0941687497225674
+     - 0.1718776009299538
+     - 0.1416541012850674
+     pca_unwrapped_1_neighbor_std:
+     - 0.3179467337379631
+     - 0.3628551035117971
+     - 0.372447324946889
+     pca_unwrapped_2:
+     - -0.673828125
+     - -0.85986328125
+     - 0.94140625
+     pca_unwrapped_2_neighbor_mean:
+     - 0.2258744673295454
+     - 0.2427867542613636
+     - 0.0790349786931818
+     pca_unwrapped_2_neighbor_std:
+     - 0.9134250264562896
+     - 0.8928014788058292
+     - 0.9167197839332804
+     pca_unwrapped_3:
+     - -0.0302886962890625
+     - -0.86572265625
+     - 0.57177734375
+     pca_unwrapped_3_neighbor_mean:
+     - -0.2933238636363636
+     - -0.2173753218217329
+     - -0.3480571400035511
+     pca_unwrapped_3_neighbor_std:
+     - 0.6203425764161097
+     - 0.5938304683645145
+     - 0.5600074530240728
+     pca_unwrapped_4:
+     - 0.67333984375
+     - -0.0005474090576171
+     - 0.81982421875
+     pca_unwrapped_4_neighbor_mean:
+     - 0.2915762121027166
+     - 0.3528386896306818
+     - 0.2782594507390802
+     pca_unwrapped_4_neighbor_std:
+     - 0.6415192812587974
+     - 0.6430080201673403
+     - 0.6308895861182334
+     pca_unwrapped_5:
+     - 0.73876953125
+     - 0.50048828125
+     - -0.03192138671875
+     pca_unwrapped_5_neighbor_mean:
+     - 0.2028697620738636
+     - 0.2245316938920454
+     - 0.2729325727982954
+     pca_unwrapped_5_neighbor_std:
+     - 0.265173781606759
+     - 0.2994363858938455
+     - 0.2968562365279343
+     pca_unwrapped_6:
+     - 0.99951171875
+     - 0.05828857421875
+     - -0.77880859375
+     pca_unwrapped_6_neighbor_mean:
+     - -0.2386505820534446
+     - -0.1530848416415128
+     - -0.0769850990988991
+     pca_unwrapped_6_neighbor_std:
+     - 0.6776577717043619
+     - 0.7717860533115238
+     - 0.7447135522384378
+     pca_unwrapped_7:
+     - 0.023834228515625
+     - -0.9931640625
+     - 0.52978515625
+     pca_unwrapped_7_neighbor_mean:
+     - -0.4803272594105113
+     - -0.3878728693181818
+     - -0.5263227982954546
+     pca_unwrapped_7_neighbor_std:
+     - 0.4799926318285017
+     - 0.4691567465869561
+     - 0.3891669942534205
+     pca_unwrapped_8:
+     - 0.0192413330078125
+     - 0.0997314453125
+     - -0.3359375
+     pca_unwrapped_8_neighbor_mean:
+     - -0.0384375832297585
+     - -0.0457548661665482
+     - -0.0061485984108664
+     pca_unwrapped_8_neighbor_std:
+     - 0.3037878488292577
+     - 0.3010843368506175
+     - 0.2874409267860334
+     pca_val_unwrapped_0:
+     - 15657.09765625
+     - 40668.40625
+     - 66863.0
+     pca_val_unwrapped_0_neighbor_mean:
+     - 69378.52059659091
+     - 67104.76526988637
+     - 64723.43856534091
+     pca_val_unwrapped_0_neighbor_std:
+     - 20242.245019019712
+     - 24702.906417865197
+     - 25959.16138296664
+     pca_val_unwrapped_1:
+     - 11305.3017578125
+     - 34416.42578125
+     - 59273.25
+     pca_val_unwrapped_1_neighbor_mean:
+     - 41190.40261008523
+     - 39089.39133522727
+     - 36829.68004261364
+     pca_val_unwrapped_1_neighbor_std:
+     - 16625.870141811894
+     - 18875.56976212627
+     - 17666.778281657556
+     pca_val_unwrapped_2:
+     - 1270.4095458984375
+     - 13551.6748046875
+     - 47764.625
+     pca_val_unwrapped_2_neighbor_mean:
+     - 28717.50048828125
+     - 27601.021828391335
+     - 24490.75362881747
+     pca_val_unwrapped_2_neighbor_std:
+     - 14988.204981576571
+     - 16601.48080038032
+     - 15622.078784778376
+     post_synapse_count:
+     - 0.0
+     - 0.0
+     - 0.0
+     post_synapse_count_neighbor_mean:
+     - 0.0
+     - 0.0
+     - 0.0
+     post_synapse_count_neighbor_std:
+     - 0.0
+     - 0.0
+     - 0.0
+     pre_synapse_count:
+     - 0.0
+     - 0.0
+     - 0.0
+     pre_synapse_count_neighbor_mean:
+     - 0.0
+     - 0.0
+     - 0.0
+     pre_synapse_count_neighbor_std:
+     - 0.0
+     - 0.0
+     - 0.0
+     size_nm3:
+     - 12771840.0
+     - 697943040.0
+     - 7550330880.0
+     size_nm3_neighbor_mean:
+     - 3233702034.285714
+     - 3184761234.285714
+     - 2695304960.0
+     size_nm3_neighbor_std:
+     - 3650678969.7909584
+     - 3691650923.5639486
+     - 3518520747.0511127
+ ---
+
+ # Model description
+
+ [More Information Needed]
+
+ ## Intended uses & limitations
+
+ [More Information Needed]
+
+ ## Training Procedure
+
+ [More Information Needed]
+
+ ### Hyperparameters
+
+ <details>
+ <summary> Click to expand </summary>
+
+ | Hyperparameter | Value |
+ |------------------------------------|---------------------------------------------------------------------------------------------------------------------------|
+ | memory | |
+ | steps | [('transformer', QuantileTransformer(output_distribution='normal')), ('lda', LinearDiscriminantAnalysis(n_components=3))] |
+ | verbose | False |
+ | transformer | QuantileTransformer(output_distribution='normal') |
+ | lda | LinearDiscriminantAnalysis(n_components=3) |
+ | transformer__copy | True |
+ | transformer__ignore_implicit_zeros | False |
+ | transformer__n_quantiles | 1000 |
+ | transformer__output_distribution | normal |
+ | transformer__random_state | |
+ | transformer__subsample | 10000 |
+ | lda__covariance_estimator | |
+ | lda__n_components | 3 |
+ | lda__priors | |
+ | lda__shrinkage | |
+ | lda__solver | svd |
+ | lda__store_covariance | False |
+ | lda__tol | 0.0001 |
+
+ </details>
+
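The hyperparameter table above describes a two-step scikit-learn pipeline: a `QuantileTransformer` that maps each feature to a normal distribution, followed by a `LinearDiscriminantAnalysis` with three components. The following is a minimal sketch of that structure, fitted on synthetic data only; the real training features and labels are not reproduced here. The helper names `build_pipeline` and `make_synthetic_data`, the feature count, and the random blobs are assumptions made for illustration, and the four label strings are taken from the model description rather than from the fitted model.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import QuantileTransformer

# Label names assumed from the model description (axon, dendrite, soma, glia).
LABELS = ["axon", "dendrite", "soma", "glia"]


def build_pipeline() -> Pipeline:
    """Recreate the pipeline structure listed in the hyperparameter table."""
    return Pipeline(
        steps=[
            ("transformer", QuantileTransformer(output_distribution="normal")),
            ("lda", LinearDiscriminantAnalysis(n_components=3)),
        ]
    )


def make_synthetic_data(n_per_class=300, n_features=6, seed=0):
    """Hypothetical stand-in for the real features: one Gaussian blob per label."""
    rng = np.random.default_rng(seed)
    X_parts, y_parts = [], []
    for label in LABELS:
        center = rng.normal(scale=5.0, size=n_features)
        X_parts.append(center + rng.normal(size=(n_per_class, n_features)))
        y_parts.append(np.full(n_per_class, label))
    return np.vstack(X_parts), np.concatenate(y_parts)


X, y = make_synthetic_data()
model = build_pipeline().fit(X, y)

# With four classes, LDA can produce at most 4 - 1 = 3 discriminant axes,
# which is consistent with n_components=3 in the table.
embedding = model.transform(X)
predictions = model.predict(X)
print(embedding.shape, sorted(set(predictions)))
```

The synthetic set uses 1200 samples so that the default `n_quantiles=1000` from the table can be used without scikit-learn reducing it to the sample count.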
+ ### Model Plot
+
284
+ <style>#sk-container-id-4 {/* Definition of color scheme common for light and dark mode */--sklearn-color-text: black;--sklearn-color-line: gray;/* Definition of color scheme for unfitted estimators */--sklearn-color-unfitted-level-0: #fff5e6;--sklearn-color-unfitted-level-1: #f6e4d2;--sklearn-color-unfitted-level-2: #ffe0b3;--sklearn-color-unfitted-level-3: chocolate;/* Definition of color scheme for fitted estimators */--sklearn-color-fitted-level-0: #f0f8ff;--sklearn-color-fitted-level-1: #d4ebff;--sklearn-color-fitted-level-2: #b3dbfd;--sklearn-color-fitted-level-3: cornflowerblue;/* Specific color for light theme */--sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));--sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));--sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));--sklearn-color-icon: #696969;@media (prefers-color-scheme: dark) {/* Redefinition of color scheme for dark theme */--sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));--sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));--sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));--sklearn-color-icon: #878787;}
285
+ }#sk-container-id-4 {color: var(--sklearn-color-text);
286
+ }#sk-container-id-4 pre {padding: 0;
287
+ }#sk-container-id-4 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;
288
+ }#sk-container-id-4 div.sk-dashed-wrapped {border: 1px dashed var(--sklearn-color-line);margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: var(--sklearn-color-background);
289
+ }#sk-container-id-4 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }`but bootstrap.min.css set `[hidden] { display: none !important; }`so we also need the `!important` here to be able to override thedefault hidden behavior on the sphinx rendered scikit-learn.org.See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;
290
+ }#sk-container-id-4 div.sk-text-repr-fallback {display: none;
291
+ }div.sk-parallel-item,
292
+ div.sk-serial,
293
+ div.sk-item {/* draw centered vertical line to link estimators */background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));background-size: 2px 100%;background-repeat: no-repeat;background-position: center center;
294
+ }/* Parallel-specific style estimator block */#sk-container-id-4 div.sk-parallel-item::after {content: "";width: 100%;border-bottom: 2px solid var(--sklearn-color-text-on-default-background);flex-grow: 1;
295
+ }#sk-container-id-4 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: var(--sklearn-color-background);position: relative;
296
+ }#sk-container-id-4 div.sk-parallel-item {display: flex;flex-direction: column;
297
+ }#sk-container-id-4 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;
298
+ }#sk-container-id-4 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;
299
+ }#sk-container-id-4 div.sk-parallel-item:only-child::after {width: 0;
300
+ }/* Serial-specific style estimator block */#sk-container-id-4 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: var(--sklearn-color-background);padding-right: 1em;padding-left: 1em;
301
+ }/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is
302
+ clickable and can be expanded/collapsed.
303
+ - Pipeline and ColumnTransformer use this feature and define the default style
304
+ - Estimators will overwrite some part of the style using the `sk-estimator` class
305
+ *//* Pipeline and ColumnTransformer style (default) */#sk-container-id-4 div.sk-toggleable {/* Default theme specific background. It is overwritten whether we have aspecific estimator or a Pipeline/ColumnTransformer */background-color: var(--sklearn-color-background);
306
+ }/* Toggleable label */
307
+ #sk-container-id-4 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.5em;box-sizing: border-box;text-align: center;
308
+ }#sk-container-id-4 label.sk-toggleable__label-arrow:before {/* Arrow on the left of the label */content: "▸";float: left;margin-right: 0.25em;color: var(--sklearn-color-icon);
309
+ }#sk-container-id-4 label.sk-toggleable__label-arrow:hover:before {color: var(--sklearn-color-text);
310
+ }/* Toggleable content - dropdown */#sk-container-id-4 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;/* unfitted */background-color: var(--sklearn-color-unfitted-level-0);
311
+ }#sk-container-id-4 div.sk-toggleable__content.fitted {/* fitted */background-color: var(--sklearn-color-fitted-level-0);
312
+ }#sk-container-id-4 div.sk-toggleable__content pre {margin: 0.2em;border-radius: 0.25em;color: var(--sklearn-color-text);/* unfitted */background-color: var(--sklearn-color-unfitted-level-0);
313
+ }#sk-container-id-4 div.sk-toggleable__content.fitted pre {/* unfitted */background-color: var(--sklearn-color-fitted-level-0);
314
+ }#sk-container-id-4 input.sk-toggleable__control:checked~div.sk-toggleable__content {/* Expand drop-down */max-height: 200px;max-width: 100%;overflow: auto;
315
+ }#sk-container-id-4 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: "▾";
316
+ }/* Pipeline/ColumnTransformer-specific style */#sk-container-id-4 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {color: var(--sklearn-color-text);background-color: var(--sklearn-color-unfitted-level-2);
317
+ }#sk-container-id-4 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: var(--sklearn-color-fitted-level-2);
318
+ }/* Estimator-specific style *//* Colorize estimator box */
319
+ #sk-container-id-4 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {/* unfitted */background-color: var(--sklearn-color-unfitted-level-2);
320
+ }#sk-container-id-4 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {/* fitted */background-color: var(--sklearn-color-fitted-level-2);
321
+ }#sk-container-id-4 div.sk-label label.sk-toggleable__label,
322
+ #sk-container-id-4 div.sk-label label {/* The background is the default theme color */color: var(--sklearn-color-text-on-default-background);
323
+ }/* On hover, darken the color of the background */
324
+ #sk-container-id-4 div.sk-label:hover label.sk-toggleable__label {color: var(--sklearn-color-text);background-color: var(--sklearn-color-unfitted-level-2);
325
+ }/* Label box, darken color on hover, fitted */
326
+ #sk-container-id-4 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {color: var(--sklearn-color-text);background-color: var(--sklearn-color-fitted-level-2);
327
+ }/* Estimator label */#sk-container-id-4 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;
328
+ }#sk-container-id-4 div.sk-label-container {text-align: center;
329
+ }/* Estimator-specific */
330
+ #sk-container-id-4 div.sk-estimator {font-family: monospace;border: 1px dotted var(--sklearn-color-border-box);border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;/* unfitted */background-color: var(--sklearn-color-unfitted-level-0);
331
+ }#sk-container-id-4 div.sk-estimator.fitted {/* fitted */background-color: var(--sklearn-color-fitted-level-0);
332
+ }/* on hover */
333
+ #sk-container-id-4 div.sk-estimator:hover {/* unfitted */background-color: var(--sklearn-color-unfitted-level-2);
334
+ }#sk-container-id-4 div.sk-estimator.fitted:hover {/* fitted */background-color: var(--sklearn-color-fitted-level-2);
335
+ }/* Specification for estimator info (e.g. "i" and "?") *//* Common style for "i" and "?" */.sk-estimator-doc-link,
336
+ a:link.sk-estimator-doc-link,
337
+ a:visited.sk-estimator-doc-link {float: right;font-size: smaller;line-height: 1em;font-family: monospace;background-color: var(--sklearn-color-background);border-radius: 1em;height: 1em;width: 1em;text-decoration: none !important;margin-left: 1ex;/* unfitted */border: var(--sklearn-color-unfitted-level-1) 1pt solid;color: var(--sklearn-color-unfitted-level-1);
338
+ }.sk-estimator-doc-link.fitted,
339
+ a:link.sk-estimator-doc-link.fitted,
340
+ a:visited.sk-estimator-doc-link.fitted {/* fitted */border: var(--sklearn-color-fitted-level-1) 1pt solid;color: var(--sklearn-color-fitted-level-1);
341
+ }/* On hover */
342
+ div.sk-estimator:hover .sk-estimator-doc-link:hover,
343
+ .sk-estimator-doc-link:hover,
344
+ div.sk-label-container:hover .sk-estimator-doc-link:hover,
345
+ .sk-estimator-doc-link:hover {/* unfitted */background-color: var(--sklearn-color-unfitted-level-3);color: var(--sklearn-color-background);text-decoration: none;
346
+ }div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,
347
+ .sk-estimator-doc-link.fitted:hover,
348
+ div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,
349
+ .sk-estimator-doc-link.fitted:hover {/* fitted */background-color: var(--sklearn-color-fitted-level-3);color: var(--sklearn-color-background);text-decoration: none;
350
+ }/* Span, style for the box shown on hovering the info icon */
351
+ .sk-estimator-doc-link span {display: none;z-index: 9999;position: relative;font-weight: normal;right: .2ex;padding: .5ex;margin: .5ex;width: min-content;min-width: 20ex;max-width: 50ex;color: var(--sklearn-color-text);box-shadow: 2pt 2pt 4pt #999;/* unfitted */background: var(--sklearn-color-unfitted-level-0);border: .5pt solid var(--sklearn-color-unfitted-level-3);
352
+ }.sk-estimator-doc-link.fitted span {/* fitted */background: var(--sklearn-color-fitted-level-0);border: var(--sklearn-color-fitted-level-3);
353
+ }.sk-estimator-doc-link:hover span {display: block;
354
+ }/* "?"-specific style due to the `<a>` HTML tag */#sk-container-id-4 a.estimator_doc_link {float: right;font-size: 1rem;line-height: 1em;font-family: monospace;background-color: var(--sklearn-color-background);border-radius: 1rem;height: 1rem;width: 1rem;text-decoration: none;/* unfitted */color: var(--sklearn-color-unfitted-level-1);border: var(--sklearn-color-unfitted-level-1) 1pt solid;
355
+ }#sk-container-id-4 a.estimator_doc_link.fitted {/* fitted */border: var(--sklearn-color-fitted-level-1) 1pt solid;color: var(--sklearn-color-fitted-level-1);
356
+ }/* On hover */
357
+ #sk-container-id-4 a.estimator_doc_link:hover {/* unfitted */background-color: var(--sklearn-color-unfitted-level-3);color: var(--sklearn-color-background);text-decoration: none;
358
+ }#sk-container-id-4 a.estimator_doc_link.fitted:hover {/* fitted */background-color: var(--sklearn-color-fitted-level-3);
359
+ }
360
+ </style><div id="sk-container-id-4" class="sk-top-container" style="overflow: auto;"><div class="sk-text-repr-fallback"><pre>Pipeline(steps=[(&#x27;transformer&#x27;,QuantileTransformer(output_distribution=&#x27;normal&#x27;)),(&#x27;lda&#x27;, LinearDiscriminantAnalysis(n_components=3))])</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class="sk-container" hidden><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label fitted sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="sk-estimator-id-10" type="checkbox" ><label for="sk-estimator-id-10" class="sk-toggleable__label fitted sk-toggleable__label-arrow fitted">&nbsp;&nbsp;Pipeline<a class="sk-estimator-doc-link fitted" rel="noreferrer" target="_blank" href="https://scikit-learn.org/1.4/modules/generated/sklearn.pipeline.Pipeline.html">?<span>Documentation for Pipeline</span></a><span class="sk-estimator-doc-link fitted">i<span>Fitted</span></span></label><div class="sk-toggleable__content fitted"><pre>Pipeline(steps=[(&#x27;transformer&#x27;,QuantileTransformer(output_distribution=&#x27;normal&#x27;)),(&#x27;lda&#x27;, LinearDiscriminantAnalysis(n_components=3))])</pre></div> </div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator fitted sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="sk-estimator-id-11" type="checkbox" ><label for="sk-estimator-id-11" class="sk-toggleable__label fitted sk-toggleable__label-arrow fitted">&nbsp;QuantileTransformer<a class="sk-estimator-doc-link fitted" rel="noreferrer" target="_blank" href="https://scikit-learn.org/1.4/modules/generated/sklearn.preprocessing.QuantileTransformer.html">?<span>Documentation for QuantileTransformer</span></a></label><div class="sk-toggleable__content 
fitted"><pre>QuantileTransformer(output_distribution=&#x27;normal&#x27;)</pre></div> </div></div><div class="sk-item"><div class="sk-estimator fitted sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="sk-estimator-id-12" type="checkbox" ><label for="sk-estimator-id-12" class="sk-toggleable__label fitted sk-toggleable__label-arrow fitted">&nbsp;LinearDiscriminantAnalysis<a class="sk-estimator-doc-link fitted" rel="noreferrer" target="_blank" href="https://scikit-learn.org/1.4/modules/generated/sklearn.discriminant_analysis.LinearDiscriminantAnalysis.html">?<span>Documentation for LinearDiscriminantAnalysis</span></a></label><div class="sk-toggleable__content fitted"><pre>LinearDiscriminantAnalysis(n_components=3)</pre></div> </div></div></div></div></div></div>
+
+ ## Evaluation Results
+
+ [More Information Needed]
+
+ # How to Get Started with the Model
+
+ [More Information Needed]
+
+ # Model Card Authors
+
+ This model card is written by the following authors:
+
+ [More Information Needed]
+
+ # Model Card Contact
+
+ You can contact the model card authors through the following channels:
+ [More Information Needed]
+
+ # Citation
+
+ Below you can find information related to citation.
+
+ **BibTeX:**
+ ```
+ [More Information Needed]
+ ```
+
+ # model_card_authors
+
+ bdpedigo
+
+ # model_description
+
+ This is a model trained to classify pieces of a neuron as axon, dendrite, soma, or glia, based only on their local shape and synapse features. The model is a linear discriminant classifier trained on compartment labels generated by Bethanny Danskin for three 6x6x6 µm boxes in the Minnie65 Phase3 dataset.
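The card leaves its getting-started section unfilled, so the following is only a hedged sketch of how the uploaded `local_compartment_classifier_bd_boxes.skops` file might be loaded and applied. It assumes `skops` and `pandas` are installed, that the model file and `config.json` from this commit sit in the working directory, and that the feature table uses the column names listed in `config.json`. The helper names `load_classifier` and `predict_compartments` are hypothetical, and the label values the model returns are not stated in the source.

```python
import json
from pathlib import Path

MODEL_FILE = "local_compartment_classifier_bd_boxes.skops"
CONFIG_FILE = "config.json"


def load_classifier(path=MODEL_FILE):
    """Load the skops artifact after listing the types that need review."""
    # Imported lazily so this sketch can be read without skops installed.
    from skops.io import get_untrusted_types, load

    unknown_types = get_untrusted_types(file=path)
    # Inspect this list before trusting it; it should contain nothing unexpected.
    print("Types requiring review:", unknown_types)
    return load(path, trusted=unknown_types)


def predict_compartments(model, example_input, columns):
    """Arrange features in the column order from config.json, then predict."""
    import pandas as pd

    features = pd.DataFrame(example_input)[columns]
    return model.predict(features)


# Only runs when both files from the commit are available locally.
if Path(MODEL_FILE).exists() and Path(CONFIG_FILE).exists():
    config = json.loads(Path(CONFIG_FILE).read_text())["sklearn"]
    classifier = load_classifier()
    print(predict_compartments(classifier, config["example_input"], config["columns"]))
```

Listing the untrusted types first, rather than trusting the file blindly, follows the loading pattern that skops documents for persisted models.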
config.json ADDED
@@ -0,0 +1,360 @@
1
+ {
2
+ "sklearn": {
3
+ "columns": [
4
+ "area_nm2",
5
+ "max_dt_nm",
6
+ "mean_dt_nm",
7
+ "size_nm3",
8
+ "pca_unwrapped_0",
9
+ "pca_unwrapped_1",
10
+ "pca_unwrapped_2",
11
+ "pca_unwrapped_3",
12
+ "pca_unwrapped_4",
13
+ "pca_unwrapped_5",
14
+ "pca_unwrapped_6",
15
+ "pca_unwrapped_7",
16
+ "pca_unwrapped_8",
17
+ "pca_val_unwrapped_0",
18
+ "pca_val_unwrapped_1",
19
+ "pca_val_unwrapped_2",
20
+ "pca_ratio_01",
21
+ "pre_synapse_count",
22
+ "post_synapse_count",
23
+ "area_nm2_neighbor_mean",
24
+ "area_nm2_neighbor_std",
25
+ "max_dt_nm_neighbor_mean",
26
+ "max_dt_nm_neighbor_std",
27
+ "mean_dt_nm_neighbor_mean",
28
+ "mean_dt_nm_neighbor_std",
29
+ "size_nm3_neighbor_mean",
30
+ "size_nm3_neighbor_std",
31
+ "pca_unwrapped_0_neighbor_mean",
32
+ "pca_unwrapped_0_neighbor_std",
33
+ "pca_unwrapped_1_neighbor_mean",
34
+ "pca_unwrapped_1_neighbor_std",
35
+ "pca_unwrapped_2_neighbor_mean",
36
+ "pca_unwrapped_2_neighbor_std",
37
+ "pca_unwrapped_3_neighbor_mean",
38
+ "pca_unwrapped_3_neighbor_std",
39
+ "pca_unwrapped_4_neighbor_mean",
40
+ "pca_unwrapped_4_neighbor_std",
41
+ "pca_unwrapped_5_neighbor_mean",
42
+ "pca_unwrapped_5_neighbor_std",
43
+ "pca_unwrapped_6_neighbor_mean",
44
+ "pca_unwrapped_6_neighbor_std",
45
+ "pca_unwrapped_7_neighbor_mean",
46
+ "pca_unwrapped_7_neighbor_std",
47
+ "pca_unwrapped_8_neighbor_mean",
48
+ "pca_unwrapped_8_neighbor_std",
49
+ "pca_val_unwrapped_0_neighbor_mean",
50
+ "pca_val_unwrapped_0_neighbor_std",
51
+ "pca_val_unwrapped_1_neighbor_mean",
52
+ "pca_val_unwrapped_1_neighbor_std",
53
+ "pca_val_unwrapped_2_neighbor_mean",
54
+ "pca_val_unwrapped_2_neighbor_std",
55
+ "pca_ratio_01_neighbor_mean",
56
+ "pca_ratio_01_neighbor_std",
57
+ "pre_synapse_count_neighbor_mean",
58
+ "pre_synapse_count_neighbor_std",
59
+ "post_synapse_count_neighbor_mean",
60
+ "post_synapse_count_neighbor_std"
61
+ ],
62
+ "environment": [
63
+ "scikit-learn",
64
+ "caveclient"
65
+ ],
66
+ "example_input": {
67
+ "area_nm2": [
68
+ 693824.0,
69
+ 4852608.0,
70
+ 17088896.0
71
+ ],
72
+ "area_nm2_neighbor_mean": [
73
+ 10181485.714285716,
74
+ 9884429.714285716,
75
+ 9010409.142857144
76
+ ],
77
+ "area_nm2_neighbor_std": [
78
+ 8312409.263207569,
79
+ 8587259.418816902,
80
+ 8418630.640116522
81
+ ],
82
+ "max_dt_nm": [
83
+ 69.0,
84
+ 543.0,
85
+ 1287.0
86
+ ],
87
+ "max_dt_nm_neighbor_mean": [
88
+ 664.7142857142857,
89
+ 630.8571428571429,
90
+ 577.7142857142857
91
+ ],
92
+ "max_dt_nm_neighbor_std": [
93
+ 479.64240342658945,
94
+ 504.9563358340017,
95
+ 468.41868657651344
96
+ ],
97
+ "mean_dt_nm": [
98
+ 24.4375,
99
+ 156.5,
100
+ 416.0
101
+ ],
102
+ "mean_dt_nm_neighbor_mean": [
103
+ 198.62946428571428,
104
+ 189.19642857142856,
105
+ 170.66071428571428
106
+ ],
107
+ "mean_dt_nm_neighbor_std": [
108
+ 150.614304054458,
109
+ 157.4368957825056,
110
+ 143.32375093543624
111
+ ],
112
+ "pca_ratio_01": [
113
+ 1.3849340770961909,
114
+ 1.181656878273399,
115
+ 1.128046800200765
116
+ ],
117
+ "pca_ratio_01_neighbor_mean": [
118
+ 1.8575624906424115,
119
+ 1.8760422359899387,
120
+ 1.880915879451087
121
+ ],
122
+ "pca_ratio_01_neighbor_std": [
123
+ 0.641580757345606,
124
+ 0.6228187048854344,
125
+ 0.6165585104590592
126
+ ],
127
+ "pca_unwrapped_0": [
128
+ -0.0046539306640625,
129
+ -0.497314453125,
130
+ -0.258544921875
131
+ ],
132
+ "pca_unwrapped_0_neighbor_mean": [
133
+ 0.039224624633789,
134
+ 0.0840119448575106,
135
+ 0.0623056238347833
136
+ ],
137
+ "pca_unwrapped_0_neighbor_std": [
138
+ 0.3114910605258688,
139
+ 0.2573427692683507,
140
+ 0.296254177168357
141
+ ],
142
+ "pca_unwrapped_1": [
143
+ 0.7392578125,
144
+ -0.11553955078125,
145
+ 0.2169189453125
146
+ ],
147
+ "pca_unwrapped_1_neighbor_mean": [
148
+ 0.0941687497225674,
149
+ 0.1718776009299538,
150
+ 0.1416541012850674
151
+ ],
152
+ "pca_unwrapped_1_neighbor_std": [
153
+ 0.3179467337379631,
154
+ 0.3628551035117971,
155
+ 0.372447324946889
156
+ ],
157
+ "pca_unwrapped_2": [
158
+ -0.673828125,
159
+ -0.85986328125,
160
+ 0.94140625
161
+ ],
162
+ "pca_unwrapped_2_neighbor_mean": [
163
+ 0.2258744673295454,
164
+ 0.2427867542613636,
165
+ 0.0790349786931818
166
+ ],
167
+ "pca_unwrapped_2_neighbor_std": [
168
+ 0.9134250264562896,
169
+ 0.8928014788058292,
170
+ 0.9167197839332804
171
+ ],
172
+ "pca_unwrapped_3": [
173
+ -0.0302886962890625,
174
+ -0.86572265625,
175
+ 0.57177734375
176
+ ],
177
+ "pca_unwrapped_3_neighbor_mean": [
178
+ -0.2933238636363636,
179
+ -0.2173753218217329,
180
+ -0.3480571400035511
181
+ ],
182
+ "pca_unwrapped_3_neighbor_std": [
183
+ 0.6203425764161097,
184
+ 0.5938304683645145,
185
+ 0.5600074530240728
186
+ ],
187
+ "pca_unwrapped_4": [
188
+ 0.67333984375,
189
+ -0.0005474090576171,
190
+ 0.81982421875
191
+ ],
192
+ "pca_unwrapped_4_neighbor_mean": [
193
+ 0.2915762121027166,
194
+ 0.3528386896306818,
195
+ 0.2782594507390802
196
+ ],
197
+ "pca_unwrapped_4_neighbor_std": [
198
+ 0.6415192812587974,
199
+ 0.6430080201673403,
200
+ 0.6308895861182334
201
+ ],
202
+ "pca_unwrapped_5": [
203
+ 0.73876953125,
204
+ 0.50048828125,
205
+ -0.03192138671875
206
+ ],
207
+ "pca_unwrapped_5_neighbor_mean": [
208
+ 0.2028697620738636,
209
+ 0.2245316938920454,
210
+ 0.2729325727982954
211
+ ],
212
+ "pca_unwrapped_5_neighbor_std": [
213
+ 0.265173781606759,
214
+ 0.2994363858938455,
215
+ 0.2968562365279343
216
+ ],
217
+ "pca_unwrapped_6": [
218
+ 0.99951171875,
219
+ 0.05828857421875,
220
+ -0.77880859375
221
+ ],
222
+ "pca_unwrapped_6_neighbor_mean": [
223
+ -0.2386505820534446,
224
+ -0.1530848416415128,
225
+ -0.0769850990988991
226
+ ],
227
+ "pca_unwrapped_6_neighbor_std": [
228
+ 0.6776577717043619,
229
+ 0.7717860533115238,
230
+ 0.7447135522384378
231
+ ],
232
+ "pca_unwrapped_7": [
233
+ 0.023834228515625,
234
+ -0.9931640625,
235
+ 0.52978515625
236
+ ],
237
+ "pca_unwrapped_7_neighbor_mean": [
238
+ -0.4803272594105113,
239
+ -0.3878728693181818,
240
+ -0.5263227982954546
241
+ ],
242
+ "pca_unwrapped_7_neighbor_std": [
243
+ 0.4799926318285017,
244
+ 0.4691567465869561,
245
+ 0.3891669942534205
246
+ ],
247
+ "pca_unwrapped_8": [
248
+ 0.0192413330078125,
249
+ 0.0997314453125,
250
+ -0.3359375
251
+ ],
252
+ "pca_unwrapped_8_neighbor_mean": [
253
+ -0.0384375832297585,
254
+ -0.0457548661665482,
255
+ -0.0061485984108664
256
+ ],
257
+ "pca_unwrapped_8_neighbor_std": [
258
+ 0.3037878488292577,
259
+ 0.3010843368506175,
260
+ 0.2874409267860334
261
+ ],
262
+ "pca_val_unwrapped_0": [
263
+ 15657.09765625,
264
+ 40668.40625,
265
+ 66863.0
266
+ ],
267
+ "pca_val_unwrapped_0_neighbor_mean": [
268
+ 69378.52059659091,
269
+ 67104.76526988637,
270
+ 64723.43856534091
271
+ ],
272
+ "pca_val_unwrapped_0_neighbor_std": [
273
+ 20242.245019019712,
274
+ 24702.906417865197,
275
+ 25959.16138296664
276
+ ],
277
+ "pca_val_unwrapped_1": [
278
+ 11305.3017578125,
279
+ 34416.42578125,
280
+ 59273.25
281
+ ],
282
+ "pca_val_unwrapped_1_neighbor_mean": [
283
+ 41190.40261008523,
284
+ 39089.39133522727,
285
+ 36829.68004261364
286
+ ],
287
+ "pca_val_unwrapped_1_neighbor_std": [
288
+ 16625.870141811894,
289
+ 18875.56976212627,
290
+ 17666.778281657556
291
+ ],
292
+ "pca_val_unwrapped_2": [
293
+ 1270.4095458984375,
294
+ 13551.6748046875,
295
+ 47764.625
296
+ ],
297
+ "pca_val_unwrapped_2_neighbor_mean": [
298
+ 28717.50048828125,
299
+ 27601.021828391335,
300
+ 24490.75362881747
301
+ ],
302
+ "pca_val_unwrapped_2_neighbor_std": [
303
+ 14988.204981576571,
304
+ 16601.48080038032,
305
+ 15622.078784778376
306
+ ],
307
+ "post_synapse_count": [
308
+ 0.0,
309
+ 0.0,
310
+ 0.0
311
+ ],
312
+ "post_synapse_count_neighbor_mean": [
313
+ 0.0,
314
+ 0.0,
315
+ 0.0
316
+ ],
317
+ "post_synapse_count_neighbor_std": [
318
+ 0.0,
319
+ 0.0,
320
+ 0.0
321
+ ],
322
+ "pre_synapse_count": [
323
+ 0.0,
324
+ 0.0,
325
+ 0.0
326
+ ],
327
+ "pre_synapse_count_neighbor_mean": [
328
+ 0.0,
329
+ 0.0,
330
+ 0.0
331
+ ],
332
+ "pre_synapse_count_neighbor_std": [
333
+ 0.0,
334
+ 0.0,
335
+ 0.0
336
+ ],
337
+ "size_nm3": [
338
+ 12771840.0,
339
+ 697943040.0,
340
+ 7550330880.0
341
+ ],
342
+ "size_nm3_neighbor_mean": [
343
+ 3233702034.285714,
344
+ 3184761234.285714,
345
+ 2695304960.0
346
+ ],
347
+ "size_nm3_neighbor_std": [
348
+ 3650678969.7909584,
349
+ 3691650923.5639486,
350
+ 3518520747.0511127
351
+ ]
352
+ },
353
+ "model": {
354
+ "file": "local_compartment_classifier_bd_boxes.skops"
355
+ },
356
+ "model_format": "skops",
357
+ "task": "tabular-classification",
358
+ "use_intelex": false
359
+ }
360
+ }
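The example data in `config.json` is stored column by column, which is awkward to read one sample at a time. A minimal sketch of turning it into per-sample rows with only the standard library follows. The `sklearn` and `example_input` key names are assumptions based on the usual skops config layout, `example_rows` is a hypothetical helper, and the inline JSON is a trimmed stand-in for the real file with two columns only.

```python
import json

# Trimmed, hypothetical stand-in for config.json. The "sklearn" and
# "example_input" key names are assumptions about the skops layout.
config_text = """
{
  "sklearn": {
    "example_input": {
      "area_nm2": [693824.0, 4852608.0, 17088896.0],
      "max_dt_nm": [69.0, 543.0, 1287.0]
    },
    "model": {"file": "local_compartment_classifier_bd_boxes.skops"},
    "model_format": "skops",
    "task": "tabular-classification"
  }
}
"""


def example_rows(config):
    """Convert the column-oriented example input into one dict per sample."""
    columns = config["sklearn"]["example_input"]
    names = list(columns)
    # zip(*values) walks the columns in lockstep, yielding one tuple per sample.
    return [dict(zip(names, values)) for values in zip(*columns.values())]


config = json.loads(config_text)
rows = example_rows(config)
print(rows[0])
```

Each resulting dict corresponds to one level-2 node, so a row can be inspected or compared against a feature table without loading the model.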
local_compartment_classifier_bd_boxes.skops ADDED
Binary file (515 kB). View file
 
train.py ADDED
@@ -0,0 +1,359 @@
1
+ # %%
2
+
3
+
4
+ from pathlib import Path
5
+
6
+ import caveclient as cc
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ import seaborn as sns
11
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
12
+ from sklearn.ensemble import RandomForestClassifier
13
+ from sklearn.metrics import classification_report
14
+ from sklearn.model_selection import KFold
15
+ from sklearn.pipeline import Pipeline
16
+ from sklearn.preprocessing import QuantileTransformer
17
+ from skops.io import dump
18
+
19
+ client = cc.CAVEclient("minnie65_phase3_v1")
20
+
21
+ out_path = Path("./troglobyte-sandbox/models/")
22
+
23
+ model_name = "local_compartment_classifier_bd_boxes"
24
+
25
+ data_path = Path("./troglobyte-sandbox/data/bounding_box_labels")
26
+
27
+ files = list(data_path.glob("*.csv"))
28
+
29
+ # %%
30
+ label_df = pd.read_csv(out_path / model_name / "labels.csv", index_col=[0, 1])
31
+ label_df = label_df.rename(columns=lambda x: x.replace(".1", ""))
32
+
33
+ # # %%
34
+
35
+ # X_df = wrangler.features_.copy()
36
+ # X_df = X_df.drop(columns=[col for col in X_df.columns if "rep_coord" in col])
37
+
38
+ # %%
39
+
40
+ X_df = pd.read_csv(out_path / model_name / "features.csv", index_col=[0, 1])
41
+
42
+
43
+ # %%
44
+
45
+
46
+ def box_train_test_split(
47
+ train_box_indices, test_box_indices, X_df, label_df, label_column
48
+ ):
49
+ train_label_df = label_df.loc[train_box_indices + 1].droplevel("bbox_id")
50
+ test_label_df = label_df.loc[test_box_indices + 1].droplevel("bbox_id")
51
+
52
+ train_X_df = X_df.loc[train_label_df["root_id"]]
53
+ test_X_df = X_df.loc[test_label_df["root_id"]]
54
+ train_X_df = train_X_df.dropna()
55
+ test_X_df = test_X_df.dropna()
56
+
57
+ train_l2_y = train_X_df.index.get_level_values("object_id").map(
58
+ train_label_df[label_column]
59
+ )
60
+ test_l2_y = test_X_df.index.get_level_values("object_id").map(
61
+ test_label_df[label_column]
62
+ )
63
+
64
+ # TODO: evaluate more fairly on uncertain (unlabeled) nodes; for now they are dropped
65
+ train_X_df = train_X_df.loc[train_l2_y.notna()]
66
+ train_l2_y = train_l2_y[train_l2_y.notna()].values.astype(str)
67
+
68
+ test_X_df = test_X_df.loc[test_l2_y.notna()]
69
+ test_l2_y = test_l2_y[test_l2_y.notna()].values.astype(str)
70
+
71
+ return train_X_df, test_X_df, train_l2_y, test_l2_y
72
+
73
+
74
+ def aggregate_votes_by_object(X_df, l2_node_predictions):
75
+ l2_node_predictions = pd.Series(
76
+ index=X_df.index, data=l2_node_predictions, name="label"
77
+ )
78
+
79
+ object_prediction_counts = (
80
+ l2_node_predictions.groupby(level="object_id").value_counts().to_frame()
81
+ )
82
+
83
+ object_n_predictions = object_prediction_counts.groupby("object_id").sum()
84
+
85
+ sufficient_data_index = object_n_predictions.query("count > 3").index
86
+
87
+ object_prediction_counts = object_prediction_counts.loc[sufficient_data_index]
88
+
89
+ object_prediction_probs = object_prediction_counts.unstack(fill_value=0)
90
+ object_prediction_probs = object_prediction_probs.div(
91
+ object_prediction_probs.sum(axis=1), axis=0
92
+ )
93
+
94
+ object_prediction_counts.reset_index(drop=False, inplace=True)
95
+
96
+ max_locs = object_prediction_counts.groupby("object_id")["count"].idxmax()
97
+
98
+ max_predictions = object_prediction_counts.loc[max_locs]
99
+ max_predictions["proportion"] = (
100
+ max_predictions["count"]
101
+ / object_n_predictions.loc[max_predictions["object_id"]]["count"].values
102
+ )
103
+ max_predictions = max_predictions.set_index("object_id")
104
+ return max_predictions, object_prediction_probs
105
+
106
+
107
+ # models to evaluate
108
+ def get_lda(n_classes):
109
+ lda = Pipeline(
110
+ [
111
+ ("transformer", QuantileTransformer(output_distribution="normal")),
112
+ ("lda", LinearDiscriminantAnalysis(n_components=n_classes - 1)),
113
+ ]
114
+ )
115
+ return lda
116
+
117
+
118
+ rf = RandomForestClassifier(n_estimators=500, max_depth=4)
119
+
120
+ box_indices = np.arange(1, 4)
121
+
122
+ rows = []
123
+ for fold, (train_box_indices, test_box_indices) in enumerate(
124
+ KFold(n_splits=3).split(box_indices.reshape(-1, 1))
125
+ ):
126
+ for label_column in ["axon_label", "simple_label"]:
127
+ train_X_df, test_X_df, train_l2_y, test_l2_y = box_train_test_split(
128
+ train_box_indices, test_box_indices, X_df, label_df, label_column
129
+ )
130
+ n_classes = label_df[label_column].nunique()
131
+ models = {"rf": rf, "lda": get_lda(n_classes)}
132
+ for model_type, model in models.items():
133
+ model.fit(train_X_df, train_l2_y)
134
+ train_preds = model.predict(train_X_df)
135
+ test_preds = model.predict(test_X_df)
136
+
137
+ # evaluate at the L2 level
138
+ train_report = classification_report(
139
+ train_l2_y, train_preds, output_dict=True
140
+ )
141
+ rows.append(
142
+ {
143
+ "model": model_type,
144
+ "fold": fold,
145
+ "accuracy": train_report["accuracy"],
146
+ "macro_f1": train_report["macro avg"]["f1-score"],
147
+ "weighted_f1": train_report["weighted avg"]["f1-score"],
148
+ "evaluation": "train",
149
+ "labeling": label_column,
150
+ "level": "level2",
151
+ }
152
+ )
153
+
154
+ test_report = classification_report(test_l2_y, test_preds, output_dict=True)
155
+ rows.append(
156
+ {
157
+ "model": model_type,
158
+ "fold": fold,
159
+ "accuracy": test_report["accuracy"],
160
+ "macro_f1": test_report["macro avg"]["f1-score"],
161
+ "weighted_f1": test_report["weighted avg"]["f1-score"],
162
+ "evaluation": "test",
163
+ "labeling": label_column,
164
+ "level": "level2",
165
+ }
166
+ )
167
+
168
+ # evaluate at the object level
169
+ train_object_predictions, train_object_probs = aggregate_votes_by_object(
170
+ train_X_df, train_preds
171
+ )
172
+ train_object_y = (
173
+ label_df.droplevel(0)
174
+ .loc[train_object_predictions.index, label_column]
175
+ .values.astype(str)
176
+ )
177
+ train_object_report = classification_report(
178
+ train_object_y, train_object_predictions["label"], output_dict=True
179
+ )
180
+ rows.append(
181
+ {
182
+ "model": model_type + "-vote",
183
+ "fold": fold,
184
+ "accuracy": train_object_report["accuracy"],
185
+ "macro_f1": train_object_report["macro avg"]["f1-score"],
186
+ "weighted_f1": train_object_report["weighted avg"]["f1-score"],
187
+ "evaluation": "train",
188
+ "labeling": label_column,
189
+ "level": "root",
190
+ }
191
+ )
192
+
193
+ test_object_predictions, test_object_probs = aggregate_votes_by_object(
194
+ test_X_df, test_preds
195
+ )
196
+ test_object_y = (
197
+ label_df.droplevel(0)
198
+ .loc[test_object_predictions.index, label_column]
199
+ .values.astype(str)
200
+ )
201
+ test_object_report = classification_report(
202
+ test_object_y, test_object_predictions["label"], output_dict=True
203
+ )
204
+ rows.append(
205
+ {
206
+ "model": model_type + "-vote",
207
+ "fold": fold,
208
+ "accuracy": test_object_report["accuracy"],
209
+ "macro_f1": test_object_report["macro avg"]["f1-score"],
210
+ "weighted_f1": test_object_report["weighted avg"]["f1-score"],
211
+ "evaluation": "test",
212
+ "labeling": label_column,
213
+ "level": "root",
214
+ }
215
+ )
216
+
217
+
218
+ # %%
219
+
220
+ evaluation_df = pd.DataFrame(rows)
221
+
222
+ sns.set_context("talk")
223
+
224
+ fig, axs = plt.subplots(2, 3, figsize=(15, 10), constrained_layout=True, sharey="col")
225
+ for i, labeling in enumerate(["simple_label", "axon_label"]):
226
+ for j, metric in enumerate(["accuracy", "weighted_f1", "macro_f1"]):
227
+ ax = axs[i, j]
228
+ show_legend = i == 0 and j == 0
229
+ sns.stripplot(
230
+ data=evaluation_df.query("labeling == @labeling"),
231
+ x="model",
232
+ y=metric,
233
+ hue="evaluation",
234
+ ax=ax,
235
+ legend=show_legend,
236
+ s=10,
237
+ jitter=True,
238
+ )
239
+ ax.spines[["right", "top"]].set_visible(False)
240
+ if j == 1:
241
+ ax.set_title("Labeling: " + labeling)
242
+
243
+
244
+ # %%
245
+ lda = model  # last model fit in the loop above: LDA, final fold, "simple_label"
246
+ train_X_transformed = lda.transform(train_X_df)
247
+
248
+ # %%
249
+ fig, ax = plt.subplots(1, 1, figsize=(10, 10))
250
+ sns.scatterplot(
251
+ x=train_X_transformed[:, 0],
252
+ y=train_X_transformed[:, 1],
253
+ hue=train_l2_y,
254
+ ax=ax,
255
+ s=10,
256
+ alpha=0.7,
257
+ )
258
+ ax.set(xticks=[], yticks=[], xlabel="LDA1", ylabel="LDA2")
259
+ ax.spines[["right", "top"]].set_visible(False)
260
+ # %%
261
+ fig, ax = plt.subplots(1, 1, figsize=(10, 10))
262
+ sns.scatterplot(
263
+ x=train_X_transformed[:, 0],
264
+ y=train_X_transformed[:, 2],
265
+ hue=train_l2_y,
266
+ ax=ax,
267
+ s=10,
268
+ alpha=0.7,
269
+ )
270
+ ax.set(xticks=[], yticks=[], xlabel="LDA1", ylabel="LDA3")
271
+ ax.spines[["right", "top"]].set_visible(False)
272
+
273
+
274
+ # %%
275
+ final_lda = Pipeline(
276
+ [
277
+ ("transformer", QuantileTransformer(output_distribution="normal")),
278
+ ("lda", LinearDiscriminantAnalysis(n_components=n_classes - 1)),
279
+ ]
280
+ )
281
+
282
+ train_X_df, test_X_df, train_l2_y, test_l2_y = box_train_test_split(
283
+ np.array([0, 1, 2]), np.array([], dtype=int), X_df, label_df, label_column
284
+ )
285
+
286
+ final_lda.fit(train_X_df, train_l2_y)
287
+
288
+ # %%
289
+
290
+ model_pickle_file = out_path / model_name / f"{model_name}.skops"
291
+ with open(model_pickle_file, mode="bw") as f:
292
+ dump(final_lda, file=f)
293
+
294
+ # %%
295
+ from pathlib import Path
296
+
297
+ from skops import card, hub_utils
298
+
299
+ hub_out_path = Path(
300
+ "troglobyte-sandbox/models/local_compartment_classifier_bd_boxes/hub"
301
+ )
302
+ if not hub_out_path.exists():
303
+ hub_utils.init(
304
+ model=model_pickle_file,
305
+ requirements=["scikit-learn", "caveclient"],
306
+ dst=hub_out_path,
307
+ task="tabular-classification",
308
+ data=train_X_df,
309
+ )
310
+
311
+ hub_utils.add_files(__file__, dst=hub_out_path, exist_ok=True)
312
+
313
+ model_card = card.Card(final_lda, metadata=card.metadata_from_config(hub_out_path))
314
+
315
+ model_card.metadata.license = "mit"
316
+
317
+ model_description = (
318
+ "This is a model trained to classify pieces of neuron as axon, dendrite, soma, or "
319
+ "glia, "
320
+ "based only on their local shape and synapse features. "
321
+ "The model is a linear discriminant classifier which was trained on compartment "
322
+ "labels generated by Bethanny Danskin for 3 6x6x6 um boxes in the Minnie65 Phase3 "
323
+ "dataset."
324
+ )
325
+
326
+ model_card_authors = "bdpedigo"
327
+
328
+ model_card.add(
329
+ model_card_authors=model_card_authors,
330
+ model_description=model_description,
331
+ )
332
+
333
+ model_card.save(hub_out_path / "README.md")
334
+
335
+ hub_utils.push(
336
+ repo_id=f"bdpedigo/{model_name}",
337
+ source=hub_out_path,
338
+ create_remote=False,
339
+ private=False,
340
+ )
341
+
342
+ # %%
343
+
344
+ syn_features = [col for col in X_df.columns if "syn" in col]
345
+ train_X_df_no_syn = train_X_df.drop(columns=syn_features)
346
+
347
+ final_lda_no_syn = Pipeline(
348
+ [
349
+ ("transformer", QuantileTransformer(output_distribution="normal")),
350
+ ("lda", LinearDiscriminantAnalysis(n_components=n_classes - 1)),
351
+ ]
352
+ )
353
+
354
+ final_lda_no_syn.fit(train_X_df_no_syn, train_l2_y)
355
+
356
+ print(classification_report(train_l2_y, final_lda_no_syn.predict(train_X_df_no_syn)))
357
+
358
+ with open(out_path / model_name / f"{model_name}_no_syn.skops", mode="bw") as f:
359
+ dump(final_lda_no_syn, file=f)
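The object-level evaluation in `train.py` takes a majority vote over the level-2 node predictions for each object and drops objects with three or fewer predictions. A minimal standard-library sketch of that idea follows. It is not the pandas implementation in `aggregate_votes_by_object`; the function name `aggregate_votes`, the `min_votes` parameter, and the toy labels are assumptions made for illustration.

```python
from collections import Counter, defaultdict


def aggregate_votes(object_ids, predictions, min_votes=4):
    """Majority vote per object; objects with fewer than min_votes are skipped.

    min_votes=4 mirrors the `count > 3` filter in train.py (an assumption
    about intent, not a copy of the pandas code).
    """
    votes = defaultdict(Counter)
    for object_id, label in zip(object_ids, predictions):
        votes[object_id][label] += 1

    results = {}
    for object_id, counts in votes.items():
        total = sum(counts.values())
        if total < min_votes:
            continue  # too few node-level predictions to trust a vote
        label, count = counts.most_common(1)[0]
        # Store the winning label and the fraction of votes it received.
        results[object_id] = (label, count / total)
    return results


# Toy data: object 2 has only two predictions, so it is filtered out.
object_ids = [1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3]
predictions = [
    "axon", "axon", "dendrite", "axon",
    "soma", "soma",
    "glia", "glia", "glia", "axon", "glia",
]
winners = aggregate_votes(object_ids, predictions)
print(winners)
```

The winning fraction plays the same role as the `proportion` column in the script: a rough confidence for the object-level label.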