GoodBaiBai88 committed on
Commit
988ef80
1 Parent(s): 49696db

Update README.md

Files changed (1)
  1. README.md +48 -0
README.md CHANGED
@@ -1,3 +1,51 @@
  ---
  license: apache-2.0
+ metrics:
+ - accuracy
+ pipeline_tag: image-feature-extraction
+ tags:
+ - 3D medical CLIP
+ - Image-text retrieval
  ---
+
+ M3D-CLIP is a 3D medical CLIP model that aligns vision and language through a contrastive loss on the [M3D-Cap](https://huggingface.co/datasets/GoodBaiBai88/M3D-Cap) dataset.
+ The vision encoder is a 3D ViT with an input size of 32×256×256 and a patch size of 4×16×16.
+ The text encoder is initialized from a pre-trained BERT.
+
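As a quick sanity check on the encoder geometry described above, the sketch below computes how many patches a 32×256×256 volume yields with a 4×16×16 patch size. It assumes non-overlapping patches, as in a standard ViT; the helper name `patch_grid` is illustrative and is not part of M3D-CLIP.

```python
def patch_grid(image_size, patch_size):
    """Return the number of non-overlapping patches along each axis."""
    if any(s % p != 0 for s, p in zip(image_size, patch_size)):
        raise ValueError("image size must be divisible by patch size on every axis")
    return tuple(s // p for s, p in zip(image_size, patch_size))


# Depth, height, width as stated in the model description.
grid = patch_grid((32, 256, 256), (4, 16, 16))
num_patches = grid[0] * grid[1] * grid[2]
print(grid, num_patches)  # (8, 16, 16) 2048
```

Under that assumption, each volume becomes a sequence of 2048 patch tokens before any extra tokens the model may add.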
+ # Quickstart
+
+ ```python
+ import numpy as np
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ device = torch.device("cuda")  # or "cpu"
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "GoodBaiBai88/M3D-CLIP",
+     model_max_length=512,
+     padding_side="right",
+     use_fast=False
+ )
+ model = AutoModel.from_pretrained(
+     "GoodBaiBai88/M3D-CLIP",
+     trust_remote_code=True
+ )
+ model = model.to(device=device)
+
+ # Prepare your 3D medical image:
+ # 1. The image must have shape 1*32*256*256; resize or interpolate it if needed.
+ # 2. The image must be normalized to the range 0-1, e.g. with min-max normalization.
+ # 3. The image must be saved in .npy format.
+ # 4. Although we did not train on 2D images, in theory a 2D image can be interpolated to the shape 1*32*256*256 and used as input.
+
+ image_path = ""
+ input_txt = ""
+
+ text_tensor = tokenizer(input_txt, return_tensors="pt")
+ input_id = text_tensor["input_ids"].to(device=device)
+ attention_mask = text_tensor["attention_mask"].to(device=device)
+ # np.load returns a NumPy array, which has no .to(); convert it to a tensor first.
+ # Assumption: the model expects a leading batch dimension, i.e. 1*1*32*256*256.
+ image = torch.from_numpy(np.load(image_path)).float().unsqueeze(0).to(device=device)
+
+ with torch.inference_mode():
+     image_features = model.encode_image(image)[:, 0]
+     text_features = model.encode_text(input_id, attention_mask)[:, 0]
+
+ ```
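Because the model is trained with a contrastive loss and tagged for image-text retrieval, one plausible way to use the two feature vectors is to compare them by cosine similarity. The sketch below is an assumption about usage, not a documented M3D-CLIP API: `cosine_similarity_matrix` is a hypothetical helper, and the small arrays stand in for real model outputs.

```python
import numpy as np


def cosine_similarity_matrix(image_features, text_features):
    """Cosine similarity between every image row and every text row."""
    img = np.asarray(image_features, dtype=np.float64)
    txt = np.asarray(text_features, dtype=np.float64)
    # L2-normalize each row so the dot product equals the cosine similarity.
    img = img / np.linalg.norm(img, axis=1, keepdims=True)
    txt = txt / np.linalg.norm(txt, axis=1, keepdims=True)
    return img @ txt.T


# Toy stand-ins for encoder outputs (2 images, 2 texts, 2-dim features).
image_feats = np.array([[1.0, 0.0], [0.0, 2.0]])
text_feats = np.array([[0.0, 3.0], [4.0, 0.0]])

scores = cosine_similarity_matrix(image_feats, text_feats)
best_text_per_image = scores.argmax(axis=1)  # index of the closest text for each image
```

With real outputs, the tensors would first need to be moved to NumPy, for example with `image_features.cpu().numpy()`.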
+
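The preparation steps listed in the Quickstart comments (resize to 1*32*256*256, scale to 0-1, save as .npy) can be sketched as follows. This is a minimal illustration, not the authors' preprocessing pipeline: `preprocess_volume` is a hypothetical helper, nearest-neighbour resampling is used only to keep the example dependency-free (trilinear interpolation may be preferable in practice), and the random volume stands in for a real scan.

```python
import numpy as np


def preprocess_volume(volume, target_shape=(32, 256, 256)):
    """Resample a (D, H, W) volume to target_shape, scale to [0, 1], add a channel axis."""
    volume = np.asarray(volume, dtype=np.float32)
    if volume.ndim != 3:
        raise ValueError("expected a 3D volume with shape (D, H, W)")
    # Nearest-neighbour resampling: pick evenly spaced source indices on each axis.
    idx = [
        np.round(np.linspace(0, src - 1, dst)).astype(int)
        for src, dst in zip(volume.shape, target_shape)
    ]
    resized = volume[np.ix_(*idx)]
    # Min-max normalization; a constant volume is mapped to all zeros.
    lo, hi = resized.min(), resized.max()
    scale = hi - lo
    normalized = (resized - lo) / scale if scale > 0 else np.zeros_like(resized)
    return normalized[np.newaxis, ...]  # shape (1, 32, 256, 256)


# Synthetic volume standing in for a real scan.
rng = np.random.default_rng(0)
raw = rng.normal(size=(40, 300, 280)) * 500.0
image = preprocess_volume(raw)
# np.save("example_volume.npy", image)  # hypothetical output path
```

The saved array can then be loaded with `np.load` as in the Quickstart.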