Samuel Stevens commited on
Commit
2edabd1
·
1 Parent(s): 3e841d9

update to use gradients

Browse files
Files changed (2) hide show
  1. app.py +18 -13
  2. requirements.txt +53 -305
app.py CHANGED
@@ -10,6 +10,7 @@ import random
10
  import beartype
11
  import einops.layers.torch
12
  import gradio as gr
 
13
  import numpy as np
14
  import open_clip
15
  import requests
@@ -18,7 +19,7 @@ import torch
18
  from jaxtyping import Float, jaxtyped
19
  from PIL import Image, ImageDraw
20
  from torch import Tensor
21
- from torchvision.transforms import v2
22
 
23
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
24
  logging.basicConfig(level=logging.INFO, format=log_format)
@@ -29,10 +30,10 @@ logger = logging.getLogger("app.py")
29
  # Global Constants #
30
  ####################
31
 
32
- DEBUG = True
33
  """Whether we are debugging."""
34
 
35
- n_sae_latents = 3
36
  """Number of SAE latents to show."""
37
 
38
  n_sae_examples = 4
@@ -54,6 +55,7 @@ CWD = pathlib.Path(__file__).parent
54
 
55
  r2_url = "https://pub-289086e849214430853bc87bd8964988.r2.dev/"
56
 
 
57
 
58
  logger.info("Set global constants.")
59
 
@@ -117,7 +119,6 @@ def make_img(
117
  (resize_w_px + crop_w_px) // 2,
118
  (resize_h_px + crop_h_px) // 2,
119
  )
120
-
121
  img = img.resize(resize_size_px).crop(crop_coords_px)
122
  img = add_highlights(img, patches.numpy(), upper=upper, opacity=0.5)
123
  return img
@@ -207,10 +208,10 @@ logger.info("Loaded SAE.")
207
  # Datasets #
208
  ############
209
 
210
- human_transform = v2.Compose([
211
- v2.Resize((512, 512), interpolation=v2.InterpolationMode.NEAREST),
212
- v2.CenterCrop((448, 448)),
213
- v2.ToImage(),
214
  einops.layers.torch.Rearrange("channels width height -> width height channels"),
215
  ])
216
 
@@ -258,7 +259,10 @@ mask = mask & (sparsity < max_frequency)
258
  def get_image(image_i: int) -> list[Image.Image | int]:
259
  image = get_dataset_img(image_i)
260
  image = human_transform(image)
261
- return [Image.fromarray(image.numpy()), image_labels[image_i]]
 
 
 
262
 
263
 
264
  @beartype.beartype
@@ -268,7 +272,7 @@ def get_random_class_image(cls: int) -> Image.Image:
268
 
269
  image = get_dataset_img(i)
270
  image = human_transform(image)
271
- return Image.fromarray(image.numpy())
272
 
273
 
274
  @torch.inference_mode
@@ -430,8 +434,9 @@ def add_highlights(
430
  overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
431
  draw = ImageDraw.Draw(overlay)
432
 
433
- # Using semi-transparent red (255, 0, 0, alpha)
434
- for p, val in enumerate(patches):
 
435
  assert upper is not None
436
  val /= upper + 1e-9
437
  x_np, y_np = p % iw_np, p // ih_np
@@ -440,7 +445,7 @@ def add_highlights(
440
  (x_np * pw_px, y_np * ph_px),
441
  (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
442
  ],
443
- fill=(int(val * 256), 0, 0, int(opacity * val * 256)),
444
  )
445
 
446
  # Composite the original image and the overlay
 
10
  import beartype
11
  import einops.layers.torch
12
  import gradio as gr
13
+ import matplotlib
14
  import numpy as np
15
  import open_clip
16
  import requests
 
19
  from jaxtyping import Float, jaxtyped
20
  from PIL import Image, ImageDraw
21
  from torch import Tensor
22
+ from torchvision import transforms
23
 
24
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
25
  logging.basicConfig(level=logging.INFO, format=log_format)
 
30
  # Global Constants #
31
  ####################
32
 
33
+ DEBUG = False
34
  """Whether we are debugging."""
35
 
36
+ n_sae_latents = 5
37
  """Number of SAE latents to show."""
38
 
39
  n_sae_examples = 4
 
55
 
56
  r2_url = "https://pub-289086e849214430853bc87bd8964988.r2.dev/"
57
 
58
+ colormap = matplotlib.colormaps.get_cmap("plasma")
59
 
60
  logger.info("Set global constants.")
61
 
 
119
  (resize_w_px + crop_w_px) // 2,
120
  (resize_h_px + crop_h_px) // 2,
121
  )
 
122
  img = img.resize(resize_size_px).crop(crop_coords_px)
123
  img = add_highlights(img, patches.numpy(), upper=upper, opacity=0.5)
124
  return img
 
208
  # Datasets #
209
  ############
210
 
211
+ human_transform = transforms.Compose([
212
+ transforms.Resize((448,), interpolation=transforms.InterpolationMode.BICUBIC),
213
+ transforms.CenterCrop((448, 448)),
214
+ transforms.ToTensor(),
215
  einops.layers.torch.Rearrange("channels width height -> width height channels"),
216
  ])
217
 
 
259
  def get_image(image_i: int) -> list[Image.Image | int]:
260
  image = get_dataset_img(image_i)
261
  image = human_transform(image)
262
+ return [
263
+ Image.fromarray((image * 255).to(torch.uint8).numpy()),
264
+ image_labels[image_i],
265
+ ]
266
 
267
 
268
  @beartype.beartype
 
272
 
273
  image = get_dataset_img(i)
274
  image = human_transform(image)
275
+ return Image.fromarray((image * 255).to(torch.uint8).numpy())
276
 
277
 
278
  @torch.inference_mode
 
434
  overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
435
  draw = ImageDraw.Draw(overlay)
436
 
437
+ colors = (colormap(patches / (upper + 1e-9))[:, :3] * 255).astype(np.uint8)
438
+
439
+ for p, (val, color) in enumerate(zip(patches, colors)):
440
  assert upper is not None
441
  val /= upper + 1e-9
442
  x_np, y_np = p % iw_np, p // ih_np
 
445
  (x_np * pw_px, y_np * ph_px),
446
  (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
447
  ],
448
+ fill=(*color, int(opacity * val * 255)),
449
  )
450
 
451
  # Composite the original image and the overlay
requirements.txt CHANGED
@@ -2,9 +2,9 @@
2
  # uv pip compile pyproject.toml
3
  aiofiles==23.2.1
4
  # via gradio
5
- aiohappyeyeballs==2.4.4
6
  # via aiohttp
7
- aiohttp==3.11.11
8
  # via
9
  # datasets
10
  # fsspec
@@ -18,43 +18,22 @@ anyio==4.8.0
18
  # via
19
  # gradio
20
  # httpx
21
- # jupyter-server
22
  # starlette
23
- argon2-cffi==23.1.0
24
- # via jupyter-server
25
- argon2-cffi-bindings==21.2.0
26
- # via argon2-cffi
27
- arrow==1.3.0
28
- # via isoduration
29
- asttokens==3.0.0
30
- # via stack-data
31
- async-lru==2.0.4
32
- # via jupyterlab
33
- attrs==24.3.0
34
  # via
35
  # aiohttp
36
  # jsonschema
37
  # referencing
38
- babel==2.16.0
39
- # via jupyterlab-server
40
  beartype==0.19.0
41
  # via
42
  # saev-image-classification (pyproject.toml)
43
  # saev
44
- beautifulsoup4==4.12.3
45
- # via nbconvert
46
- bleach==6.2.0
47
- # via nbconvert
48
- braceexpand==0.1.7
49
- # via webdataset
50
- certifi==2024.12.14
51
  # via
52
  # httpcore
53
  # httpx
54
  # requests
55
- # sentry-sdk
56
- cffi==1.17.1
57
- # via argon2-cffi-bindings
58
  charset-normalizer==3.4.1
59
  # via requests
60
  click==8.1.8
@@ -62,60 +41,42 @@ click==8.1.8
62
  # marimo
63
  # typer
64
  # uvicorn
65
- # wandb
66
- cloudpickle==3.1.0
67
  # via submitit
68
- comm==0.2.2
69
- # via ipykernel
70
  contourpy==1.3.1
71
  # via matplotlib
72
  cycler==0.12.1
73
  # via matplotlib
74
- datasets==3.2.0
75
  # via saev
76
- debugpy==1.8.11
77
- # via ipykernel
78
- decorator==5.1.1
79
- # via ipython
80
- defusedxml==0.7.1
81
- # via nbconvert
82
  dill==0.3.8
83
  # via
84
  # datasets
85
  # multiprocess
86
- docker-pycreds==0.4.0
87
- # via wandb
88
  docstring-parser==0.16
89
  # via tyro
90
  docutils==0.21.2
91
  # via marimo
92
- einops==0.8.0
93
  # via
94
  # saev-image-classification (pyproject.toml)
95
  # saev
96
- executing==2.1.0
97
- # via stack-data
98
- fastapi==0.115.6
99
  # via gradio
100
- fastjsonschema==2.21.1
101
- # via nbformat
102
  ffmpy==0.5.0
103
  # via gradio
104
- filelock==3.16.1
105
  # via
106
  # datasets
107
  # huggingface-hub
108
  # torch
109
- # triton
110
- fonttools==4.55.3
111
  # via matplotlib
112
- fqdn==1.5.1
113
- # via jsonschema
114
  frozenlist==1.5.0
115
  # via
116
  # aiohttp
117
  # aiosignal
118
- fsspec==2024.9.0
119
  # via
120
  # datasets
121
  # gradio-client
@@ -123,13 +84,9 @@ fsspec==2024.9.0
123
  # torch
124
  ftfy==6.3.1
125
  # via open-clip-torch
126
- gitdb==4.0.12
127
- # via gitpython
128
- gitpython==3.1.44
129
- # via wandb
130
- gradio==5.10.0
131
  # via saev-image-classification (pyproject.toml)
132
- gradio-client==1.5.3
133
  # via gradio
134
  h11==0.14.0
135
  # via
@@ -141,9 +98,8 @@ httpx==0.28.1
141
  # via
142
  # gradio
143
  # gradio-client
144
- # jupyterlab
145
  # safehttpx
146
- huggingface-hub==0.27.1
147
  # via
148
  # datasets
149
  # gradio
@@ -154,90 +110,32 @@ idna==3.10
154
  # via
155
  # anyio
156
  # httpx
157
- # jsonschema
158
  # requests
159
  # yarl
160
- ipykernel==6.29.5
161
- # via jupyterlab
162
- ipython==8.31.0
163
- # via ipykernel
164
- isoduration==20.11.0
165
- # via jsonschema
166
  itsdangerous==2.2.0
167
  # via marimo
168
- jaxtyping==0.2.36
169
  # via
170
  # saev-image-classification (pyproject.toml)
171
  # saev
172
  jedi==0.19.2
173
- # via
174
- # ipython
175
- # marimo
176
  jinja2==3.1.5
177
  # via
178
  # altair
179
  # gradio
180
- # jupyter-server
181
- # jupyterlab
182
- # jupyterlab-server
183
- # nbconvert
184
  # torch
185
- joblib==1.4.2
186
- # via scikit-learn
187
- json5==0.10.0
188
- # via jupyterlab-server
189
- jsonpointer==3.0.0
190
- # via jsonschema
191
  jsonschema==4.23.0
192
- # via
193
- # altair
194
- # jupyter-events
195
- # jupyterlab-server
196
- # nbformat
197
  jsonschema-specifications==2024.10.1
198
  # via jsonschema
199
- jupyter-client==8.6.3
200
- # via
201
- # ipykernel
202
- # jupyter-server
203
- # nbclient
204
- jupyter-core==5.7.2
205
- # via
206
- # ipykernel
207
- # jupyter-client
208
- # jupyter-server
209
- # jupyterlab
210
- # nbclient
211
- # nbconvert
212
- # nbformat
213
- jupyter-events==0.11.0
214
- # via jupyter-server
215
- jupyter-lsp==2.2.5
216
- # via jupyterlab
217
- jupyter-server==2.15.0
218
- # via
219
- # jupyter-lsp
220
- # jupyterlab
221
- # jupyterlab-server
222
- # notebook-shim
223
- jupyter-server-terminals==0.5.3
224
- # via jupyter-server
225
- jupyterlab==4.3.4
226
- # via saev
227
- jupyterlab-pygments==0.3.0
228
- # via nbconvert
229
- jupyterlab-server==2.27.3
230
- # via jupyterlab
231
  kiwisolver==1.4.8
232
  # via matplotlib
233
- mako==1.3.8
234
- # via pdoc3
235
- marimo==0.10.9
236
  # via saev
237
  markdown==3.7
238
  # via
239
  # marimo
240
- # pdoc3
241
  # pymdown-extensions
242
  markdown-it-py==3.0.0
243
  # via rich
@@ -245,18 +143,10 @@ markupsafe==2.1.5
245
  # via
246
  # gradio
247
  # jinja2
248
- # mako
249
- # nbconvert
250
  matplotlib==3.10.0
251
  # via saev
252
- matplotlib-inline==0.1.7
253
- # via
254
- # ipykernel
255
- # ipython
256
  mdurl==0.1.2
257
  # via markdown-it-py
258
- mistune==3.1.0
259
- # via nbconvert
260
  mpmath==1.3.0
261
  # via sympy
262
  multidict==6.1.0
@@ -265,26 +155,13 @@ multidict==6.1.0
265
  # yarl
266
  multiprocess==0.70.16
267
  # via datasets
268
- narwhals==1.21.1
269
  # via
270
  # altair
271
  # marimo
272
- nbclient==0.10.2
273
- # via nbconvert
274
- nbconvert==7.16.5
275
- # via jupyter-server
276
- nbformat==5.10.4
277
- # via
278
- # jupyter-server
279
- # nbclient
280
- # nbconvert
281
- nest-asyncio==1.6.0
282
- # via ipykernel
283
  networkx==3.4.2
284
  # via torch
285
- notebook-shim==0.2.4
286
- # via jupyterlab
287
- numpy==2.2.1
288
  # via
289
  # saev-image-classification (pyproject.toml)
290
  # contourpy
@@ -292,10 +169,7 @@ numpy==2.2.1
292
  # gradio
293
  # matplotlib
294
  # pandas
295
- # scikit-learn
296
- # scipy
297
  # torchvision
298
- # webdataset
299
  nvidia-cublas-cu12==12.4.5.8
300
  # via
301
  # nvidia-cudnn-cu12
@@ -319,6 +193,8 @@ nvidia-cusparse-cu12==12.3.1.170
319
  # via
320
  # nvidia-cusolver-cu12
321
  # torch
 
 
322
  nvidia-nccl-cu12==2.21.5
323
  # via torch
324
  nvidia-nvjitlink-cu12==12.4.127
@@ -332,10 +208,8 @@ open-clip-torch==2.30.0
332
  # via
333
  # saev-image-classification (pyproject.toml)
334
  # saev
335
- orjson==3.10.13
336
  # via gradio
337
- overrides==7.7.0
338
- # via jupyter-server
339
  packaging==24.2
340
  # via
341
  # altair
@@ -343,130 +217,74 @@ packaging==24.2
343
  # gradio
344
  # gradio-client
345
  # huggingface-hub
346
- # ipykernel
347
- # jupyter-server
348
- # jupyterlab
349
- # jupyterlab-server
350
  # marimo
351
  # matplotlib
352
- # nbconvert
353
  pandas==2.2.3
354
  # via
355
  # datasets
356
  # gradio
357
- pandocfilters==1.5.1
358
- # via nbconvert
359
  parso==0.8.4
360
  # via jedi
361
- pdoc3==0.11.5
362
- # via saev
363
- pexpect==4.9.0
364
- # via ipython
365
  pillow==11.1.0
366
  # via
367
  # gradio
368
  # matplotlib
369
  # saev
370
  # torchvision
371
- platformdirs==4.3.6
372
- # via
373
- # jupyter-core
374
- # wandb
375
- polars==1.19.0
376
  # via saev
377
- prometheus-client==0.21.1
378
- # via jupyter-server
379
- prompt-toolkit==3.0.48
380
- # via ipython
381
  propcache==0.2.1
382
  # via
383
  # aiohttp
384
  # yarl
385
- protobuf==5.29.2
386
- # via wandb
387
- psutil==6.1.1
388
- # via
389
- # ipykernel
390
- # marimo
391
- # wandb
392
- ptyprocess==0.7.0
393
- # via
394
- # pexpect
395
- # terminado
396
- pure-eval==0.2.3
397
- # via stack-data
398
- pyarrow==18.1.0
399
  # via datasets
400
- pycparser==2.22
401
- # via cffi
402
- pydantic==2.10.4
403
  # via
404
  # fastapi
405
  # gradio
406
- # wandb
407
  pydantic-core==2.27.2
408
  # via pydantic
409
  pydub==0.25.1
410
  # via gradio
411
  pygments==2.19.1
412
  # via
413
- # ipython
414
  # marimo
415
- # nbconvert
416
  # rich
417
- pymdown-extensions==10.14
418
  # via marimo
419
  pyparsing==3.2.1
420
  # via matplotlib
421
  python-dateutil==2.9.0.post0
422
  # via
423
- # arrow
424
- # jupyter-client
425
  # matplotlib
426
  # pandas
427
- python-json-logger==3.2.1
428
- # via jupyter-events
429
  python-multipart==0.0.20
430
  # via gradio
431
- pytz==2024.2
432
  # via pandas
433
  pyyaml==6.0.2
434
  # via
435
  # datasets
436
  # gradio
437
  # huggingface-hub
438
- # jupyter-events
439
  # marimo
440
  # pymdown-extensions
441
  # timm
442
- # wandb
443
- # webdataset
444
- pyzmq==26.2.0
445
- # via
446
- # ipykernel
447
- # jupyter-client
448
- # jupyter-server
449
- referencing==0.35.1
450
  # via
451
  # jsonschema
452
  # jsonschema-specifications
453
- # jupyter-events
454
  regex==2024.11.6
455
  # via open-clip-torch
456
  requests==2.32.3
457
  # via
458
  # datasets
459
  # huggingface-hub
460
- # jupyterlab-server
461
- # wandb
462
- rfc3339-validator==0.1.4
463
- # via
464
- # jsonschema
465
- # jupyter-events
466
- rfc3986-validator==0.1.1
467
- # via
468
- # jsonschema
469
- # jupyter-events
470
  rich==13.9.4
471
  # via
472
  # typer
@@ -475,53 +293,31 @@ rpds-py==0.22.3
475
  # via
476
  # jsonschema
477
  # referencing
478
- ruff==0.8.6
479
  # via
480
  # gradio
481
  # marimo
482
- saev @ git+https://github.com/samuelstevens/saev@c723ff95462736d907b3c1891d3a1496eb07318f
483
  # via saev-image-classification (pyproject.toml)
484
  safehttpx==0.1.6
485
  # via gradio
486
- safetensors==0.5.1
487
  # via
488
  # open-clip-torch
489
  # timm
490
- scikit-learn==1.6.0
491
- # via saev
492
- scipy==1.15.0
493
- # via scikit-learn
494
  semantic-version==2.10.0
495
  # via gradio
496
- send2trash==1.8.3
497
- # via jupyter-server
498
- sentry-sdk==2.19.2
499
- # via wandb
500
- setproctitle==1.3.4
501
- # via wandb
502
- setuptools==75.7.0
503
- # via
504
- # jupyterlab
505
- # torch
506
- # wandb
507
  shellingham==1.5.4
508
  # via typer
509
  shtab==1.7.1
510
  # via tyro
511
  six==1.17.0
512
- # via
513
- # docker-pycreds
514
- # python-dateutil
515
- # rfc3339-validator
516
- smmap==5.0.2
517
- # via gitdb
518
  sniffio==1.3.1
519
  # via anyio
520
- soupsieve==2.6
521
- # via beautifulsoup4
522
- stack-data==0.6.3
523
- # via ipython
524
- starlette==0.41.3
525
  # via
526
  # fastapi
527
  # gradio
@@ -530,67 +326,36 @@ submitit==1.5.2
530
  # via saev
531
  sympy==1.13.1
532
  # via torch
533
- terminado==0.18.1
534
- # via
535
- # jupyter-server
536
- # jupyter-server-terminals
537
- threadpoolctl==3.5.0
538
- # via scikit-learn
539
- timm==1.0.12
540
  # via open-clip-torch
541
- tinycss2==1.4.0
542
- # via bleach
543
  tomlkit==0.13.2
544
  # via
545
  # gradio
546
  # marimo
547
- torch==2.5.1
548
  # via
549
  # saev-image-classification (pyproject.toml)
550
  # open-clip-torch
551
  # saev
552
  # timm
553
  # torchvision
554
- torchvision==0.20.1
555
  # via
556
  # saev-image-classification (pyproject.toml)
557
  # open-clip-torch
558
  # timm
559
- tornado==6.4.2
560
- # via
561
- # ipykernel
562
- # jupyter-client
563
- # jupyter-server
564
- # jupyterlab
565
- # terminado
566
  tqdm==4.67.1
567
  # via
568
  # datasets
569
  # huggingface-hub
570
  # open-clip-torch
571
  # saev
572
- traitlets==5.14.3
573
- # via
574
- # comm
575
- # ipykernel
576
- # ipython
577
- # jupyter-client
578
- # jupyter-core
579
- # jupyter-events
580
- # jupyter-server
581
- # jupyterlab
582
- # matplotlib-inline
583
- # nbclient
584
- # nbconvert
585
- # nbformat
586
- triton==3.1.0
587
  # via torch
588
  typeguard==4.4.1
589
  # via tyro
590
  typer==0.15.1
591
  # via gradio
592
- types-python-dateutil==2.9.0.20241206
593
- # via arrow
594
  typing-extensions==4.12.2
595
  # via
596
  # altair
@@ -601,44 +366,27 @@ typing-extensions==4.12.2
601
  # huggingface-hub
602
  # pydantic
603
  # pydantic-core
 
604
  # submitit
605
  # torch
606
  # typeguard
607
  # typer
608
  # tyro
609
- tyro==0.9.6
610
  # via saev
611
- tzdata==2024.2
612
  # via pandas
613
- uri-template==1.3.0
614
- # via jsonschema
615
  urllib3==2.3.0
616
- # via
617
- # requests
618
- # sentry-sdk
619
  uvicorn==0.34.0
620
  # via
621
  # gradio
622
  # marimo
623
- vl-convert-python==1.7.0
624
- # via saev
625
- wandb==0.19.1
626
- # via saev
627
  wcwidth==0.2.13
628
- # via
629
- # ftfy
630
- # prompt-toolkit
631
- webcolors==24.11.1
632
- # via jsonschema
633
- webdataset==0.2.100
634
- # via saev
635
- webencodings==0.5.1
636
- # via
637
- # bleach
638
- # tinycss2
639
- websocket-client==1.8.0
640
- # via jupyter-server
641
- websockets==14.1
642
  # via
643
  # gradio-client
644
  # marimo
 
2
  # uv pip compile pyproject.toml
3
  aiofiles==23.2.1
4
  # via gradio
5
+ aiohappyeyeballs==2.4.6
6
  # via aiohttp
7
+ aiohttp==3.11.12
8
  # via
9
  # datasets
10
  # fsspec
 
18
  # via
19
  # gradio
20
  # httpx
21
+ # pycrdt
22
  # starlette
23
+ attrs==25.1.0
 
 
 
 
 
 
 
 
 
 
24
  # via
25
  # aiohttp
26
  # jsonschema
27
  # referencing
 
 
28
  beartype==0.19.0
29
  # via
30
  # saev-image-classification (pyproject.toml)
31
  # saev
32
+ certifi==2025.1.31
 
 
 
 
 
 
33
  # via
34
  # httpcore
35
  # httpx
36
  # requests
 
 
 
37
  charset-normalizer==3.4.1
38
  # via requests
39
  click==8.1.8
 
41
  # marimo
42
  # typer
43
  # uvicorn
44
+ cloudpickle==3.1.1
 
45
  # via submitit
 
 
46
  contourpy==1.3.1
47
  # via matplotlib
48
  cycler==0.12.1
49
  # via matplotlib
50
+ datasets==3.3.0
51
  # via saev
 
 
 
 
 
 
52
  dill==0.3.8
53
  # via
54
  # datasets
55
  # multiprocess
 
 
56
  docstring-parser==0.16
57
  # via tyro
58
  docutils==0.21.2
59
  # via marimo
60
+ einops==0.8.1
61
  # via
62
  # saev-image-classification (pyproject.toml)
63
  # saev
64
+ fastapi==0.115.8
 
 
65
  # via gradio
 
 
66
  ffmpy==0.5.0
67
  # via gradio
68
+ filelock==3.17.0
69
  # via
70
  # datasets
71
  # huggingface-hub
72
  # torch
73
+ fonttools==4.56.0
 
74
  # via matplotlib
 
 
75
  frozenlist==1.5.0
76
  # via
77
  # aiohttp
78
  # aiosignal
79
+ fsspec==2024.12.0
80
  # via
81
  # datasets
82
  # gradio-client
 
84
  # torch
85
  ftfy==6.3.1
86
  # via open-clip-torch
87
+ gradio==5.16.0
 
 
 
 
88
  # via saev-image-classification (pyproject.toml)
89
+ gradio-client==1.7.0
90
  # via gradio
91
  h11==0.14.0
92
  # via
 
98
  # via
99
  # gradio
100
  # gradio-client
 
101
  # safehttpx
102
+ huggingface-hub==0.28.1
103
  # via
104
  # datasets
105
  # gradio
 
110
  # via
111
  # anyio
112
  # httpx
 
113
  # requests
114
  # yarl
 
 
 
 
 
 
115
  itsdangerous==2.2.0
116
  # via marimo
117
+ jaxtyping==0.2.38
118
  # via
119
  # saev-image-classification (pyproject.toml)
120
  # saev
121
  jedi==0.19.2
122
+ # via marimo
 
 
123
  jinja2==3.1.5
124
  # via
125
  # altair
126
  # gradio
 
 
 
 
127
  # torch
 
 
 
 
 
 
128
  jsonschema==4.23.0
129
+ # via altair
 
 
 
 
130
  jsonschema-specifications==2024.10.1
131
  # via jsonschema
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  kiwisolver==1.4.8
133
  # via matplotlib
134
+ marimo==0.11.5
 
 
135
  # via saev
136
  markdown==3.7
137
  # via
138
  # marimo
 
139
  # pymdown-extensions
140
  markdown-it-py==3.0.0
141
  # via rich
 
143
  # via
144
  # gradio
145
  # jinja2
 
 
146
  matplotlib==3.10.0
147
  # via saev
 
 
 
 
148
  mdurl==0.1.2
149
  # via markdown-it-py
 
 
150
  mpmath==1.3.0
151
  # via sympy
152
  multidict==6.1.0
 
155
  # yarl
156
  multiprocess==0.70.16
157
  # via datasets
158
+ narwhals==1.26.0
159
  # via
160
  # altair
161
  # marimo
 
 
 
 
 
 
 
 
 
 
 
162
  networkx==3.4.2
163
  # via torch
164
+ numpy==2.2.3
 
 
165
  # via
166
  # saev-image-classification (pyproject.toml)
167
  # contourpy
 
169
  # gradio
170
  # matplotlib
171
  # pandas
 
 
172
  # torchvision
 
173
  nvidia-cublas-cu12==12.4.5.8
174
  # via
175
  # nvidia-cudnn-cu12
 
193
  # via
194
  # nvidia-cusolver-cu12
195
  # torch
196
+ nvidia-cusparselt-cu12==0.6.2
197
+ # via torch
198
  nvidia-nccl-cu12==2.21.5
199
  # via torch
200
  nvidia-nvjitlink-cu12==12.4.127
 
208
  # via
209
  # saev-image-classification (pyproject.toml)
210
  # saev
211
+ orjson==3.10.15
212
  # via gradio
 
 
213
  packaging==24.2
214
  # via
215
  # altair
 
217
  # gradio
218
  # gradio-client
219
  # huggingface-hub
 
 
 
 
220
  # marimo
221
  # matplotlib
 
222
  pandas==2.2.3
223
  # via
224
  # datasets
225
  # gradio
 
 
226
  parso==0.8.4
227
  # via jedi
 
 
 
 
228
  pillow==11.1.0
229
  # via
230
  # gradio
231
  # matplotlib
232
  # saev
233
  # torchvision
234
+ polars==1.22.0
 
 
 
 
235
  # via saev
 
 
 
 
236
  propcache==0.2.1
237
  # via
238
  # aiohttp
239
  # yarl
240
+ psutil==7.0.0
241
+ # via marimo
242
+ pyarrow==19.0.0
 
 
 
 
 
 
 
 
 
 
 
243
  # via datasets
244
+ pycrdt==0.11.1
245
+ # via marimo
246
+ pydantic==2.10.6
247
  # via
248
  # fastapi
249
  # gradio
 
250
  pydantic-core==2.27.2
251
  # via pydantic
252
  pydub==0.25.1
253
  # via gradio
254
  pygments==2.19.1
255
  # via
 
256
  # marimo
 
257
  # rich
258
+ pymdown-extensions==10.14.3
259
  # via marimo
260
  pyparsing==3.2.1
261
  # via matplotlib
262
  python-dateutil==2.9.0.post0
263
  # via
 
 
264
  # matplotlib
265
  # pandas
 
 
266
  python-multipart==0.0.20
267
  # via gradio
268
+ pytz==2025.1
269
  # via pandas
270
  pyyaml==6.0.2
271
  # via
272
  # datasets
273
  # gradio
274
  # huggingface-hub
 
275
  # marimo
276
  # pymdown-extensions
277
  # timm
278
+ referencing==0.36.2
 
 
 
 
 
 
 
279
  # via
280
  # jsonschema
281
  # jsonschema-specifications
 
282
  regex==2024.11.6
283
  # via open-clip-torch
284
  requests==2.32.3
285
  # via
286
  # datasets
287
  # huggingface-hub
 
 
 
 
 
 
 
 
 
 
288
  rich==13.9.4
289
  # via
290
  # typer
 
293
  # via
294
  # jsonschema
295
  # referencing
296
+ ruff==0.9.6
297
  # via
298
  # gradio
299
  # marimo
300
+ saev @ git+https://github.com/samuelstevens/saev@83a442210c81b09e71e68e63ff7e5aadb84e0e87
301
  # via saev-image-classification (pyproject.toml)
302
  safehttpx==0.1.6
303
  # via gradio
304
+ safetensors==0.5.2
305
  # via
306
  # open-clip-torch
307
  # timm
 
 
 
 
308
  semantic-version==2.10.0
309
  # via gradio
310
+ setuptools==75.8.0
311
+ # via torch
 
 
 
 
 
 
 
 
 
312
  shellingham==1.5.4
313
  # via typer
314
  shtab==1.7.1
315
  # via tyro
316
  six==1.17.0
317
+ # via python-dateutil
 
 
 
 
 
318
  sniffio==1.3.1
319
  # via anyio
320
+ starlette==0.45.3
 
 
 
 
321
  # via
322
  # fastapi
323
  # gradio
 
326
  # via saev
327
  sympy==1.13.1
328
  # via torch
329
+ timm==1.0.14
 
 
 
 
 
 
330
  # via open-clip-torch
 
 
331
  tomlkit==0.13.2
332
  # via
333
  # gradio
334
  # marimo
335
+ torch==2.6.0
336
  # via
337
  # saev-image-classification (pyproject.toml)
338
  # open-clip-torch
339
  # saev
340
  # timm
341
  # torchvision
342
+ torchvision==0.21.0
343
  # via
344
  # saev-image-classification (pyproject.toml)
345
  # open-clip-torch
346
  # timm
 
 
 
 
 
 
 
347
  tqdm==4.67.1
348
  # via
349
  # datasets
350
  # huggingface-hub
351
  # open-clip-torch
352
  # saev
353
+ triton==3.2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  # via torch
355
  typeguard==4.4.1
356
  # via tyro
357
  typer==0.15.1
358
  # via gradio
 
 
359
  typing-extensions==4.12.2
360
  # via
361
  # altair
 
366
  # huggingface-hub
367
  # pydantic
368
  # pydantic-core
369
+ # referencing
370
  # submitit
371
  # torch
372
  # typeguard
373
  # typer
374
  # tyro
375
+ tyro==0.9.14
376
  # via saev
377
+ tzdata==2025.1
378
  # via pandas
 
 
379
  urllib3==2.3.0
380
+ # via requests
 
 
381
  uvicorn==0.34.0
382
  # via
383
  # gradio
384
  # marimo
385
+ wadler-lindig==0.1.3
386
+ # via jaxtyping
 
 
387
  wcwidth==0.2.13
388
+ # via ftfy
389
+ websockets==14.2
 
 
 
 
 
 
 
 
 
 
 
 
390
  # via
391
  # gradio-client
392
  # marimo