File size: 2,218 Bytes
d5345e2
cffabcf
 
1366c30
 
aa31199
cffabcf
 
 
 
 
 
 
 
 
 
58a320e
cffabcf
 
 
 
 
1366c30
d5345e2
3de06ed
 
8f1a737
aa31199
e810c3c
 
0047d73
f15bc50
e810c3c
 
 
 
aa31199
d5345e2
 
ba9c45e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cffabcf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import streamlit as st
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, CLIPProcessor
from medclip.modeling_hybrid_clip import FlaxHybridCLIP

@st.cache(allow_output_mutation=True)
def load_model():
    """Load the MedCLIP model and its CLIP processor (cached per session).

    Returns:
        tuple: (FlaxHybridCLIP model, CLIPProcessor) pulled from the
        Hugging Face hub.
    """
    clip_model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return clip_model, clip_processor

@st.cache(allow_output_mutation=True)
def load_image_embeddings():
    """Read the precomputed image-embedding feature store (cached per session).

    Returns:
        tuple: (image_files, image_embeds) where image_files is an array of
        image filenames and image_embeds stacks one embedding row per file.
    """
    store = pd.read_pickle('feature_store/image_embeddings_large.pkl')
    embeds = np.stack(store['image_embedding'])
    files = np.asarray(store['files'].tolist())
    return files, embeds

# Number of top matches to display per query.
k = 5
# Directory holding the images referenced by the feature store's filenames.
img_dir = './images'

st.title("MedCLIP 🩺")
st.image("./assets/logo.png", width=100)
st.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
 [Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
st.markdown("""Example queries:
* `ultrasound scans`
* `pathology`
* `pancreatic carcinoma`""")

# Both loaders are wrapped in st.cache above, so these run once per session.
image_list, image_embeddings = load_image_embeddings()
model, processor = load_model()

# Free-text query box; the search itself runs when the button below is pressed.
query = st.text_input("Enter your query here:")

if st.button("Search"):
    with st.spinner(f"Searching ROCO test set for {query}..."):
        # Embed the query with the CLIP text tower and L2-normalize it so the
        # dot product below acts as a cosine similarity. NOTE(review): this
        # assumes the stored image embeddings were normalized when the feature
        # store was built -- confirm upstream.
        inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)

        query_embedding = model.get_text_features(**inputs)
        query_embedding = np.asarray(query_embedding)
        query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
        # Similarity between the query and every stored image embedding.
        dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
        # argsort is ascending, so [-k:] selects the k most similar images but
        # leaves the BEST match last; reverse so results render best-first.
        # (Original bug: results were displayed in worst-to-best order.)
        topk_images = dot_prod.argsort()[-k:][::-1]
        matching_images = image_list[topk_images]
        # Distance-style score (1 - similarity): lower means a closer match.
        top_scores = 1. - dot_prod[topk_images]
        # Show each match with its filename and score.
        for img_path, score in zip(matching_images, top_scores):
            img = plt.imread(os.path.join(img_dir, img_path))
            st.image(img)
            st.write(f"{img_path} ({score:.2f})", help="score")