percept-image / app.py
mdanish's picture
Upload app.py with huggingface_hub
fd8d179 verified
raw
history blame
2.29 kB
import streamlit as st
from PIL import Image
import numpy as np
import open_clip
#from transformers import CLIPProcessor, CLIPModel
# Precomputed kNN reference data loaded in main(). The filename suggests it was
# built with the same open_clip model configured below — verify before changing
# the model, or the stored vectors will not be comparable.
knnpath = '20241204-ams-no-env-open_clip_ViT-H-14-378-quickgelu.npz'
# open_clip architecture name and pretrained-weights tag, passed to
# open_clip.create_model_and_transforms / open_clip.get_tokenizer in load_model().
clip_model_name = 'ViT-H-14-378-quickgelu'
pretrained_name = 'dfn5b'

# Set page config. Kept at module level, ahead of every other st.* call
# (older Streamlit versions require set_page_config to be the first command).
st.set_page_config(
    page_title="Percept",
    layout="wide",
)
@st.cache_resource
def load_model():
    """Load the OpenCLIP model, its inference transform and its tokenizer.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of rebuilding the (large) model each time.

    Returns:
        tuple: ``(model, preprocess, tokenizer)``. ``preprocess`` is the
        evaluation transform returned by ``create_model_and_transforms``;
        the training transform (second return value) is discarded.
    """
    model, _, preprocess = open_clip.create_model_and_transforms(
        clip_model_name, pretrained=pretrained_name
    )
    # Per the open_clip README, models are created in train mode; switch to
    # eval so dropout / stochastic-depth layers behave deterministically
    # during inference. eval() is idempotent, so this is safe either way.
    model.eval()
    tokenizer = open_clip.get_tokenizer(clip_model_name)
    return model, preprocess, tokenizer
def process_image(image, preprocess, timeout=10):
    """Turn an image (or an http(s) image URL) into a batched model input.

    Args:
        image: Either an http(s) URL string, which is downloaded and opened
            with PIL, or an already-opened PIL image (anything exposing
            ``.mode`` and ``.convert``).
        preprocess: Transform applied to the RGB image; its result must
            support ``.unsqueeze(0)`` (presumably the torch tensor produced
            by the open_clip eval transform — confirm with callers).
        timeout: Seconds to wait for the download when ``image`` is a URL.

    Returns:
        ``preprocess(image)`` with a leading batch dimension of size 1.

    Raises:
        ValueError: If ``image`` is a string that is not an http(s) URL.
    """
    if isinstance(image, str):
        # Local imports: the original code relied on ``requests`` and
        # ``BytesIO`` that this file never imported, so this branch always
        # failed with NameError. The stdlib equivalents need no new dependency.
        from io import BytesIO
        from urllib.request import urlopen

        # urlopen also accepts file:// and other schemes; restrict to web URLs
        # so a caller-supplied string cannot read local files.
        if not image.lower().startswith(("http://", "https://")):
            raise ValueError(f"Unsupported image URL: {image!r}")
        with urlopen(image, timeout=timeout) as response:
            image = Image.open(BytesIO(response.read()))
    # Ensure image is in RGB mode (e.g. grayscale or RGBA uploads).
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return preprocess(image).unsqueeze(0)
def main():
    """Build the Streamlit page: load the model, report the kNN data shape,
    and show basic information about an uploaded image."""
    st.title("OpenCLIP Image Analyzer (ViT-H-14)")

    try:
        # Load model (uses st.cache_resource)
        with st.spinner('Loading model... This may take a moment.'):
            model, preprocess, tokenizer = load_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.info("Please make sure you have enough memory and the correct dependencies installed.")
        # Stop here: previously execution fell through with model,
        # preprocess and tokenizer unbound.
        return

    try:
        # np.load on an .npz returns a lazily-reading NpzFile that keeps the
        # file open; the context manager closes it after the array is read,
        # instead of leaking a handle on every Streamlit rerun.
        with np.load(knnpath) as knn:
            walkability_vecs = knn['walkability_vecs']
    except (OSError, KeyError, ValueError) as e:
        # Missing/unreadable file (OSError), missing array (KeyError) or a
        # pickled payload refused by np.load (ValueError).
        st.error(f"Error loading kNN data: {e}")
    else:
        st.write(walkability_vecs.shape)

    file = st.file_uploader('Upload An Image')
    if file:
        try:
            with Image.open(file) as img:
                st.write(file)
                st.write(img.size)
        except Exception as e:
            # UI boundary: report any decode failure instead of crashing the page.
            st.error(f"Error processing image: {str(e)}")
if __name__ == "__main__":
    # Script entry point: render the page when this module runs as __main__.
    main()