# louiecerv's picture
# fixed the handling of the streamed response
# 9dc0324
# raw
# history blame
# 6.79 kB
import os
import base64
import requests
import streamlit as st
import json
import tempfile
# Default the streaming preference to "on" the first time a session runs the script.
if "stream" not in st.session_state:
    st.session_state["stream"] = True

# The API key is read from the environment; it is None when the variable is unset.
api_key = os.environ.get("NVIDIA_VISION_API_KEY")

# Model identifier and the chat-completions endpoint derived from it.
MODEL_ID = "meta/llama-3.2-90b-vision-instruct"
invoke_url = f"https://ai.api.nvidia.com/v1/gr/{MODEL_ID}/chat/completions"
def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 string.

    The file is read in binary mode and the base64 output is decoded to a
    ``str`` using UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
# (connect, read) timeouts in seconds so a stalled API call cannot hang the app forever.
_REQUEST_TIMEOUT = (10, 120)


def _build_image_data_uri(image_bytes, mime_type=None):
    """Return *image_bytes* as a base64 ``data:`` URI.

    Args:
        image_bytes: Raw bytes of the image.
        mime_type: MIME type to advertise (e.g. ``"image/jpeg"``). Falls back
            to ``"image/png"`` when empty or ``None``.
    """
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{mime_type or 'image/png'};base64,{encoded}"


def _parse_stream_line(line):
    """Interpret one decoded line of the event stream.

    Args:
        line: A single non-empty line of the streamed response, as ``str``.

    Returns:
        A ``(text, done)`` tuple: ``text`` is the content fragment carried by
        the line (``""`` when there is none) and ``done`` is ``True`` only for
        the terminating ``data: [DONE]`` line.

    Raises:
        json.JSONDecodeError: If a ``data:`` payload is not valid JSON.
    """
    if not line.startswith("data:"):
        # Comments, event names and other non-data fields carry no text.
        return "", False
    payload = line[5:].strip()
    # Compare the whole payload: generated text may itself contain "[DONE]".
    if payload == "[DONE]":
        return "", True
    if not payload:
        return "", False
    event = json.loads(payload)
    choices = event.get("choices") if isinstance(event, dict) else None
    if not choices:
        return "", False
    # Role-only and finish chunks have no "content" (or a null one).
    delta = choices[0].get("delta") or {}
    return delta.get("content") or "", False


def _extract_message_content(body):
    """Return the first choice's message content from a non-streamed reply.

    Returns ``""`` when ``choices`` is missing or empty, or when the message
    has no content.
    """
    choices = body.get("choices") or [{}]
    message = choices[0].get("message") or {}
    return message.get("content") or ""


def _render_stream(response, placeholder):
    """Write the streamed reply into *placeholder* as fragments arrive.

    Lines are read as bytes and decoded as UTF-8 here rather than relying on
    the encoding requests infers from the response headers.
    """
    content = ""
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        line = raw_line.decode("utf-8", errors="replace")
        try:
            text, done = _parse_stream_line(line)
        except json.JSONDecodeError as e:
            st.error(f"Error parsing JSON: {e}")
            continue
        if done:
            st.write("Response generation complete.")
            break
        if text:
            content += text
            placeholder.write(content)


def main():
    """Render the page and run the selected image-analysis task.

    Shows the title and "About" panel, lets the user upload an image, pick a
    task and toggle streaming, then posts the prompt and image to
    ``invoke_url`` when "Generate Response" is clicked. Failures are reported
    with ``st.error``; nothing is returned.
    """
    st.title(f"Multimodal Image Analysis with {MODEL_ID}")

    # Display about section
    about_text = (
        "Prof. Louie F. Cervantes, M. Eng. (Information Engineering)\n"
        "CCS 229 - Intelligent Systems\n"
        "Department of Computer Science\n"
        "College of Information and Communications Technology\n"
        "West Visayas State University\n"
    )
    with st.expander("About"):
        st.text(about_text)

    st.write("Upload an image and select the image analysis task.")

    # File upload for image
    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
    image_data_uri = None
    if uploaded_image is not None:
        # Encode straight from memory: no temporary file is created, so
        # nothing is left on disk when the script reruns or returns early.
        image_data_uri = _build_image_data_uri(
            uploaded_image.getvalue(), uploaded_image.type
        )
        # Display the uploaded image
        st.image(uploaded_image, caption="Uploaded Image", use_container_width=True)

    # List of image analysis tasks
    analysis_tasks = [
        "Scene Analysis: Describe the scene depicted in the image. Identify the objects present, their spatial relationships, and any actions taking place.",
        "Object Detection and Classification: Identify and classify all objects present in the image. Provide detailed descriptions of each object, including its size, shape, color, and texture.",
        "Image Captioning: Generate a concise and accurate caption that describes the content of the image.",
        "Visual Question Answering: Answer specific questions about the image, such as 'What color is the car?' or 'How many people are in the image?'",
        "Image Similarity Search: Given a query image, find similar images from a large dataset based on visual features.",
        "Image Segmentation: Segment the image into different regions corresponding to objects or areas of interest.",
        "Optical Character Recognition (OCR): Extract text from the image, such as printed or handwritten text.",
        "Diagram Understanding: Analyze a diagram (e.g., flowchart, circuit diagram) and extract its structure and meaning.",
        "Art Analysis: Describe the artistic style, subject matter, and emotional impact of an image.",
        "Medical Image Analysis: Analyze medical images (e.g., X-rays, MRIs) to detect abnormalities or diagnose diseases.",
    ]

    # Task selection dropdown
    selected_task = st.selectbox("Select an image analysis task:", [""] + analysis_tasks)

    # Checkbox for streaming
    stream = st.checkbox(
        "Begin streaming the AI response as soon as it is available.",
        value=st.session_state.stream,
    )

    if not st.button("Generate Response"):
        return

    if not api_key:
        st.error("API key not found. Please set the NVIDIA_VISION_API_KEY environment variable.")
        return
    if uploaded_image is None:
        st.error("Please upload an image.")
        return
    if not selected_task:
        st.error("Please select an image analysis task.")
        return

    # Headers for the API call
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "text/event-stream" if stream else "application/json",
    }

    # Prepare the multimodal prompt
    payload = {
        "model": MODEL_ID,
        "messages": [
            {
                "role": "user",
                "content": f'{selected_task} <img src="{image_data_uri}" />',
            }
        ],
        "max_tokens": 512,
        "temperature": 1.0,
        "top_p": 1.0,
        "stream": stream,
    }

    try:
        with st.spinner("Processing..."):
            # The context manager closes the connection even if streaming
            # stops early or an error is raised.
            with requests.post(
                invoke_url,
                headers=headers,
                json=payload,
                stream=stream,
                timeout=_REQUEST_TIMEOUT,
            ) as response:
                response.raise_for_status()  # Raise exception for HTTP errors
                if stream:
                    # Handle streaming response
                    _render_stream(response, st.empty())
                else:
                    # Handle non-streaming response
                    content_string = _extract_message_content(response.json())
                    st.write(f"AI Response: {content_string}")
        st.success("Response generated!")
    except requests.exceptions.RequestException as e:
        st.error(f"An error occurred while making the API call: {e}")
    except Exception as e:
        # Top-level UI boundary: show the problem instead of crashing the page.
        st.error(f"An unexpected error occurred: {e}")
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()