# percept-image / app.py
# (Hugging Face Space file header — author: mdanish;
#  commit 6ca0b8b "ensure ksims and kscores are correct size"; 4.21 kB)
import streamlit as st
from PIL import Image
import numpy as np
import torch
from sklearn.utils.extmath import softmax
import open_clip
#from transformers import CLIPProcessor, CLIPModel
# Precomputed KNN reference data: per-category CLIP vectors and human
# perception scores (see knn_get_score for the expected archive keys).
knnpath = '20241204-ams-no-env-open_clip_ViT-H-14-378-quickgelu.npz'
# OpenCLIP model/pretrained pair; must match the model used to build knnpath.
clip_model_name = 'ViT-H-14-378-quickgelu'
pretrained_name = 'dfn5b'
# Perception dimensions rated by the app; each needs '<cat>_vecs' and
# '<cat>_scores' entries in the KNN archive.
categories = ['walkability', 'bikeability', 'pleasantness', 'greenness', 'safety']
# Set page config
st.set_page_config(
    page_title="Percept",
    layout="wide"
)
# When True, intermediate array shapes are written to the page for debugging.
debug = False
#st.write("Available models:", open_clip.list_models())
@st.cache_resource
def load_model():
    """Instantiate the OpenCLIP model, eval preprocessor, and tokenizer.

    Cached by Streamlit so the heavyweight checkpoint is loaded only once
    per process.
    """
    created = open_clip.create_model_and_transforms(
        clip_model_name, pretrained=pretrained_name
    )
    # create_model_and_transforms returns (model, train_tfm, eval_tfm);
    # only the eval-time transform is used by this app.
    clip_model, _train_tfm, eval_tfm = created
    text_tokenizer = open_clip.get_tokenizer(clip_model_name)
    return clip_model, eval_tfm, text_tokenizer
def process_image(image, preprocess):
    """Turn an input image into a batched tensor ready for CLIP encoding.

    Args:
        image: a PIL image object, or a string URL to download the image from.
        preprocess: the OpenCLIP eval transform returned by load_model();
            maps a PIL image to a (C, H, W) tensor.

    Returns:
        preprocess(image) with a leading batch dimension added
        (shape (1, C, H, W) for the real transform).
    """
    if isinstance(image, str):
        # Bug fix: the URL branch previously raised NameError — `requests`
        # and `BytesIO` were never imported anywhere in this file. Import
        # lazily so the common upload path has no hard dependency on them.
        import requests
        from io import BytesIO
        response = requests.get(image)
        # Fail loudly on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    # CLIP preprocessing expects 3-channel RGB input.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return preprocess(image).unsqueeze(0)
def knn_get_score(knn, k, cat, vec):
    """Score `vec` for category `cat` via similarity-weighted k-NN regression.

    Args:
        knn: mapping (np.load archive) holding, per category, reference CLIP
            vectors under f'{cat}_vecs' and their scores under f'{cat}_scores'.
        k: number of nearest neighbours to combine.
        cat: category name, e.g. 'walkability'.
        vec: query CLIP embedding; the caller passes shape (1, D), already
            L2-normalized (assumed — see main(); TODO confirm for other callers).

    Returns:
        Scalar: similarity-weighted sum of the k nearest neighbours' scores.
    """
    allvecs = knn[f'{cat}_vecs']
    if debug: st.write('allvecs.shape', allvecs.shape)
    scores = knn[f'{cat}_scores']
    if debug: st.write('scores.shape', scores.shape)
    # Cosine similarity of vec against every reference vector; both sides are
    # already normalized, so a plain dot product suffices.
    cos_sim_table = vec @ allvecs.T
    if debug: st.write('cos_sim_table.shape', cos_sim_table.shape)
    # Column indices that sort each row by similarity, descending.
    sortinds = np.flip(np.argsort(cos_sim_table, axis=1), axis=1)
    if debug: st.write('sortinds.shape', sortinds.shape)
    # Scores of the k most similar reference vectors.
    kscores = scores[sortinds][:,:k]
    if debug: st.write('kscores.shape', kscores.shape)
    # Row-wise gather of the actual sorted similarity values.
    # (line copied from clip_retrieval_knn.py even though sortinds.shape[0] == 1 here)
    ksims = cos_sim_table[np.expand_dims(np.arange(sortinds.shape[0]), axis=1), sortinds]
    ksims = ksims[:,:k]
    if debug: st.write('ksims.shape', ksims.shape)
    # Sharpen similarities exponentially (10**x), then softmax-normalize
    # them into weights that sum to 1.
    ksims = softmax(10**ksims)
    # Similarity-weighted combination of the neighbours' scores.
    kweightedscore = np.sum(kscores * ksims)
    return kweightedscore
@st.cache_resource
def load_knn():
    """Load the precomputed KNN reference archive (vectors + scores) from disk."""
    archive = np.load(knnpath)
    return archive
def main():
    """Streamlit entry point: upload an image, embed it with CLIP, and report
    a KNN-derived perception rating for each category."""
    st.title("Percept: Human Perception of Street View Image Analyzer")

    try:
        with st.spinner('Loading CLIP model... This may take a moment.'):
            model, preprocess, tokenizer = load_model()
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.info("Please make sure you have enough memory and the correct dependencies installed.")
        # Bug fix: without a model nothing below can run. Previously execution
        # fell through and crashed later with NameError on `model`/`device`.
        return

    with st.spinner('Loading KNN model... This may take a moment.'):
        knn = load_knn()
    if debug: st.write(knn['walkability_vecs'].shape)

    file = st.file_uploader('Upload An Image')
    if file:
        try:
            image = Image.open(file)
            st.image(image, caption="Uploaded Image", width=400)

            # Process image
            with st.spinner('Processing image...'):
                processed_image = process_image(image, preprocess)
                processed_image = processed_image.to(device)
                # Encode into CLIP vector
                with torch.no_grad():
                    vec = model.encode_image(processed_image)
                # Normalize so cosine similarity is a plain dot product.
                vec /= vec.norm(dim=-1, keepdim=True)
                if debug: st.write(vec.shape)
                # Bug fix: move to host memory first — Tensor.numpy() raises
                # a TypeError on CUDA tensors.
                vec = vec.cpu().numpy()

            k = 40
            for cat in categories:
                st.write(cat, f'rating = {knn_get_score(knn, k, cat, vec):.1f}')
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")


if __name__ == "__main__":
    main()