import torch
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
import streamlit as st
from PIL import Image
import os

# read the Hugging Face access token from the environment (set HF_TOKEN in your secrets)
token = os.environ.get('HF_TOKEN')

# PaliGemma 2 model checkpoint
model_id = "google/paligemma2-3b-pt-896"

@st.cache_resource
def model_setup(model_id):
    # load the model once and cache it across Streamlit reruns
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto", token=token
    ).eval()
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
    return model, processor

def runModel(prompt, image):
    # tokenize the prompt and preprocess the image, then move inputs to the model's device
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=1000, do_sample=False)
        # keep only the newly generated tokens, dropping the prompt
        generation = generation[0][input_len:]
        return processor.decode(generation, skip_special_tokens=True)

def initialize():
    # reset chat history whenever a new image is uploaded
    st.session_state.messages = []

### load model
model, processor = model_setup(model_id)

### upload a file
uploaded_file = st.file_uploader("Choose an image", on_change=initialize)
if uploaded_file:
    st.image(uploaded_file)
    image = Image.open(uploaded_file).convert("RGB")

    # tasks
    task = st.radio(
        "Task",
        ('Caption', 'OCR', 'Segment', 'Enter your prompt'),
        horizontal=True,
    )

    # display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if task == 'Enter your prompt':
        if prompt := st.chat_input("Type here!", key="question"):
            # display user message in chat message container
            with st.chat_message("user"):
                st.markdown(prompt)
            # add user message to chat history
            st.session_state.messages.append({"role": "user", "content": prompt})
            # run the VLM
            response = runModel(prompt, image)
            # display assistant response in chat message container
            with st.chat_message("assistant"):
                st.markdown(response)
            # add assistant response to chat history
            st.session_state.messages.append({"role": "assistant", "content": response})
    else:
        # display the selected task as the user message
        with st.chat_message("user"):
            st.markdown(task)
        # add user message to chat history
        st.session_state.messages.append({"role": "user", "content": task})
        # run the VLM
        response = runModel(task, image)
        # display assistant response in chat message container
        with st.chat_message("assistant"):
            st.markdown(response)
        # add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})
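
# Usage sketch (an assumption, not part of the original script): save this file
# as app.py, export your Hugging Face token, and launch the Streamlit app with:
#
#   export HF_TOKEN=<your token>
#   streamlit run app.py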