---
license: apache-2.0
datasets:
- ILSVRC/imagenet-1k
tags:
- mlx
- mlx-image
- vision
- image-classification
library_name: mlx-image
---

# regnet_y_800mf

A RegNetY-800MF image classification model, pretrained on ImageNet-1k by torchvision contributors (see the ImageNet1K-V2 weight details: https://github.com/pytorch/vision/issues/3995#new-recipe).

Disclaimer: This is a port of the torchvision model weights to the Apple MLX framework.

## How to use
```bash
pip install mlx-image
```

Here is how to use this model for image classification:

```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

# Preprocess the image with the evaluation (non-training) transform
transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.png"))
x = mx.expand_dims(x, 0)  # add a batch dimension

model = create_model("regnet_y_800mf")
model.eval()

logits = model(x)
```
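To turn the raw `logits` into class predictions, one option is a softmax followed by a top-k ranking. The sketch below uses only the Python standard library and assumes the logits for a single image have first been converted to a flat list, for example with `logits[0].tolist()`. The helpers `softmax` and `top_k` are hypothetical names for illustration, not part of mlx-image, and mapping indices to ImageNet label names is left out.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a flat list of logits."""
    m = max(logits)  # subtract the max to avoid overflow in exp()
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: List[float], k: int = 5) -> List[Tuple[int, float]]:
    """Return (class_index, probability) pairs for the k most likely classes."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


if __name__ == "__main__":
    # Toy values standing in for logits[0].tolist() (hypothetical input)
    toy = [0.1, 2.0, -1.0, 3.5]
    for idx, p in top_k(toy, k=2):
        print(f"class {idx}: {p:.3f}")
```

Sorting the full list is simple and fast enough for 1,000 ImageNet classes; a heap-based selection would only matter for much larger outputs.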

You can also extract embeddings from the layer before the classification head:
```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.png"))
x = mx.expand_dims(x, 0)  # add a batch dimension

# Option 1: create the model with num_classes=0, so the forward pass returns embeddings
model = create_model("regnet_y_800mf", num_classes=0)
model.eval()

embeds = model(x)

# Option 2: keep the classification head and call get_features
model = create_model("regnet_y_800mf")
model.eval()

embeds = model.get_features(x)
```
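A common use of these embeddings is comparing two images. The sketch below computes cosine similarity with the standard library only, assuming each embedding has been flattened to a Python list (for example with `embeds[0].tolist()`); `cosine_similarity` is a hypothetical helper for illustration, not an mlx-image function.

```python
import math
from typing import List


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two embedding vectors of equal length."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cannot compare a zero vector")
    return dot / (norm_a * norm_b)


if __name__ == "__main__":
    # Toy vectors standing in for embeds[0].tolist() from two images (hypothetical input)
    print(cosine_similarity([1.0, 0.0, 1.0], [1.0, 0.5, 1.0]))
```

Values close to 1 indicate similar images; the score ignores vector magnitude, which is usually what you want when comparing pooled features.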