jxtc commited on
Commit
798944b
·
verified ·
1 Parent(s): 4812cdb

docs: Update README

Browse files
Files changed (1) hide show
  1. README.md +63 -3
README.md CHANGED
@@ -1,3 +1,63 @@
1
- ---
2
- license: bsd-3-clause
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: bsd-3-clause
3
+ base_model:
4
+ - microsoft/resnet-50
5
+ pipeline_tag: image-feature-extraction
6
+ ---
7
+
8
+ # ResNet-50 Embeddings Only
9
+
10
+ This is a modified version of a standard ResNet-50 architecture, where the final, fully connected layer that does the classification, has been removed.
11
+
12
+ This effectively gives you the embeddings.
13
+
14
+ NB: You may want to flatten the embeddings, as it'll be of shape `(1, 20248, 1, 1)` otherwise.
15
+
16
+ # Example
17
+
18
+ ```python
19
+ import onnxruntime
20
+ from PIL import Image
21
+ from torchvision import transforms
22
+
23
+
24
+ def load_and_preprocess_image(image_path):
25
+ # Define the same preprocessing as used in training
26
+ preprocess = transforms.Compose(
27
+ [
28
+ transforms.Resize(256),
29
+ transforms.CenterCrop(224),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
32
+ ]
33
+ )
34
+
35
+ # Open the image file
36
+ img = Image.open(image_path)
37
+
38
+ # Preprocess the image
39
+ img_preprocessed = preprocess(img)
40
+
41
+ # Add batch dimension
42
+ return img_preprocessed.unsqueeze(0).numpy()
43
+
44
+
45
+ onnx_model_path = "resnet50_embeddings.onnx"
46
+
47
+ session = onnxruntime.InferenceSession(onnx_model_path)
48
+
49
+ input_name = session.get_inputs()[0].name
50
+
51
+ # Load and preprocess an image (replace with your image path)
52
+ image_path = "disco-ball.jpg"
53
+ input_data = load_and_preprocess_image(image_path)
54
+
55
+ # Run inference
56
+ outputs = session.run(None, {input_name: input_data})
57
+
58
+ # The output should be a single tensor (the embeddings)
59
+ embeddings = outputs[0]
60
+
61
+ # Flatten the embeddings
62
+ embeddings = embeddings.reshape(embeddings.shape[0], -1)
63
+ ```