# Author: louiecerv
# Commit 9dc0324: fixed the handling of the streamed response
import os
import base64
import requests
import streamlit as st
import json
import tempfile
# Seed the `stream` flag once per browser session; it starts out enabled and
# is left untouched on later Streamlit reruns.
if "stream" not in st.session_state:
    st.session_state.stream = True

# API credential comes from the environment (None when the variable is unset).
api_key = os.environ.get("NVIDIA_VISION_API_KEY")

# Model identifier; the chat-completions endpoint path is built from it so the
# two cannot drift apart.
MODEL_ID = "meta/llama-3.2-90b-vision-instruct"
invoke_url = f"https://ai.api.nvidia.com/v1/gr/{MODEL_ID}/chat/completions"
# Function to encode the image
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
# (connect, read) timeouts in seconds for the API call. The read timeout bounds
# the wait between bytes from the server, so it also guards a stalled stream.
_REQUEST_TIMEOUT = (10, 180)

# MIME type used for the image data URL when the upload reports none.
_DEFAULT_IMAGE_MIME = "image/png"


def _delta_text(data):
    """Return the text fragment carried by one streamed JSON chunk.

    Args:
        data: The JSON text that followed ``data:`` on a server-sent-events line.

    Returns:
        ``choices[0]["delta"]["content"]``, or ``""`` when the chunk has no
        choices, no delta, or a delta without ``content``.

    Raises:
        json.JSONDecodeError: If *data* is not valid JSON.
    """
    chunk = json.loads(data)
    choices = chunk.get("choices") or []
    if not choices:
        return ""
    delta = choices[0].get("delta") or {}
    return delta.get("content") or ""


def _message_text(body):
    """Return ``choices[0]["message"]["content"]`` from a non-streamed reply.

    Missing or empty ``choices``, ``message`` or ``content`` yield ``""``
    instead of raising.
    """
    choices = body.get("choices") or [{}]
    message = choices[0].get("message") or {}
    return message.get("content") or ""


def _render_stream(response):
    """Render a streamed (server-sent events) completion while it arrives.

    Each ``data:`` line carries a JSON chunk whose text fragment is appended to
    the running answer shown in a single placeholder; a ``data: [DONE]`` line
    ends the stream. Malformed JSON chunks are reported and skipped.

    Args:
        response: A ``requests`` response opened with ``stream=True``.

    Returns:
        The full accumulated response text.
    """
    # Without a charset in Content-Type, requests decodes text/* bodies as
    # ISO-8859-1 (and yields raw bytes when no Content-Type is sent at all);
    # force UTF-8 so non-ASCII model output is not garbled.
    response.encoding = "utf-8"
    response_container = st.empty()
    content = ""
    for line in response.iter_lines(decode_unicode=True):
        # Blank separator lines and non-data SSE fields carry no text.
        if not line or not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            st.write("Response generation complete.")
            break
        if not data:
            continue
        try:
            fragment = _delta_text(data)
        except json.JSONDecodeError as e:
            st.error(f"Error parsing JSON: {e}")
            continue
        if fragment:
            content += fragment
            response_container.write(content)
    return content


def main():
    """Render the Streamlit UI and run the selected analysis on an uploaded image.

    Uses the module-level ``api_key``, ``MODEL_ID`` and ``invoke_url``, and
    ``st.session_state.stream`` as the initial value of the streaming checkbox.
    When "Generate Response" is clicked and inputs are valid, the selected task
    text plus the image (as a Base64 data URL) are posted to the chat-completions
    endpoint, and the reply is shown either incrementally or in one piece.
    Errors are reported in the UI rather than raised.
    """
    st.title(f"Multimodal Image Analysis with {MODEL_ID}")

    # Display about section
    about_text = """Prof. Louie F. Cervantes, M. Eng. (Information Engineering)
CCS 229 - Intelligent Systems
Department of Computer Science
College of Information and Communications Technology
West Visayas State University
"""
    with st.expander("About"):
        st.text(about_text)

    st.write("Upload an image and select the image analysis task.")

    # File upload for image
    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
    base64_image = None
    image_mime = _DEFAULT_IMAGE_MIME
    if uploaded_image is not None:
        # Encode straight from the in-memory upload: no temporary file is
        # created, so nothing can leak on reruns that never click the button.
        base64_image = base64.b64encode(uploaded_image.getvalue()).decode()
        # Label the data URL with the upload's own type so JPEGs are not
        # sent as PNG.
        image_mime = uploaded_image.type or _DEFAULT_IMAGE_MIME
        st.image(uploaded_image, caption="Uploaded Image", use_container_width=True)

    # List of image analysis tasks
    analysis_tasks = [
        "Scene Analysis: Describe the scene depicted in the image. Identify the objects present, their spatial relationships, and any actions taking place.",
        "Object Detection and Classification: Identify and classify all objects present in the image. Provide detailed descriptions of each object, including its size, shape, color, and texture.",
        "Image Captioning: Generate a concise and accurate caption that describes the content of the image.",
        "Visual Question Answering: Answer specific questions about the image, such as 'What color is the car?' or 'How many people are in the image?'",
        "Image Similarity Search: Given a query image, find similar images from a large dataset based on visual features.",
        "Image Segmentation: Segment the image into different regions corresponding to objects or areas of interest.",
        "Optical Character Recognition (OCR): Extract text from the image, such as printed or handwritten text.",
        "Diagram Understanding: Analyze a diagram (e.g., flowchart, circuit diagram) and extract its structure and meaning.",
        "Art Analysis: Describe the artistic style, subject matter, and emotional impact of an image.",
        "Medical Image Analysis: Analyze medical images (e.g., X-rays, MRIs) to detect abnormalities or diagnose diseases.",
    ]

    # Task selection dropdown (the leading "" means "nothing selected yet")
    selected_task = st.selectbox("Select an image analysis task:", [""] + analysis_tasks)

    # Checkbox for streaming
    stream = st.checkbox("Begin streaming the AI response as soon as it is available.", value=st.session_state.stream)

    if not st.button("Generate Response"):
        return
    if not api_key:
        st.error("API key not found. Please set the NVIDIA_VISION_API_KEY environment variable.")
        return
    if uploaded_image is None:
        st.error("Please upload an image.")
        return
    if not selected_task:
        st.error("Please select an image analysis task.")
        return

    # Headers for the API call
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "text/event-stream" if stream else "application/json",
    }

    # Prepare the multimodal prompt: task text followed by an inline data URL.
    payload = {
        "model": MODEL_ID,
        "messages": [
            {
                "role": "user",
                "content": f'{selected_task} <img src="data:{image_mime};base64,{base64_image}" />',
            }
        ],
        "max_tokens": 512,
        "temperature": 1.0,
        "top_p": 1.0,
        "stream": stream,
    }

    try:
        with st.spinner("Processing..."):
            response = requests.post(
                invoke_url,
                headers=headers,
                json=payload,
                stream=stream,
                timeout=_REQUEST_TIMEOUT,
            )
        # The context manager closes the response, releasing the connection
        # even when the stream loop stops early at [DONE].
        with response:
            response.raise_for_status()  # Raise exception for HTTP errors
            if stream:
                _render_stream(response)
            else:
                content_string = _message_text(response.json())
                st.write(f"AI Response: {content_string}")
        st.success("Response generated!")
    except requests.exceptions.RequestException as e:
        st.error(f"An error occurred while making the API call: {e}")
    except Exception as e:  # top-level UI boundary: report anything else
        st.error(f"An unexpected error occurred: {e}")


if __name__ == "__main__":
    main()