|
import base64
import glob
import json
import os
import tempfile
from io import BytesIO

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import chain
from langchain_openai import ChatOpenAI
from PIL import Image as PILImage
|
|
|
|
|
st.title("Vehicle Information Extraction from Images") |
|
|
|
|
|
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password") |
|
|
|
|
|
if openai_api_key: |
|
os.environ["OPENAI_API_KEY"] = openai_api_key |
|
|
|
|
|
class Vehicle(BaseModel): |
|
Type: str = Field(..., examples=["Car", "Truck", "Motorcycle", 'Bus', 'Van'], description="The type of the vehicle.") |
|
License: str = Field(..., description="The license plate number of the vehicle.") |
|
Make: str = Field(..., examples=["Toyota", "Honda", "Ford", "Suzuki"], description="The Make of the vehicle.") |
|
Model: str = Field(..., examples=["Corolla", "Civic", "F-150"], description="The Model of the vehicle.") |
|
Color: str = Field(..., example=["Red", "Blue", "Black", "White"], description="Return the color of the vehicle.") |
|
|
|
|
|
parser = JsonOutputParser(pydantic_object=Vehicle) |
|
instructions = parser.get_format_instructions() |
|
|
|
|
|
def image_encoding(inputs): |
|
"""Load and convert image to base64 encoding""" |
|
with open(inputs["image_path"], "rb") as image_file: |
|
image_base64 = base64.b64encode(image_file.read()).decode("utf-8") |
|
return {"image": image_base64} |
|
|
|
|
|
def display_image_grid(image_paths, rows=2, cols=3, figsize=(10, 7)): |
|
fig = plt.figure(figsize=figsize) |
|
max_images = rows * cols |
|
image_paths = image_paths[:max_images] |
|
|
|
for idx, path in enumerate(image_paths): |
|
ax = fig.add_subplot(rows, cols, idx + 1) |
|
img = mpimg.imread(path) |
|
ax.imshow(img) |
|
ax.axis('off') |
|
filename = path.split('/')[-1] |
|
ax.set_title(filename) |
|
|
|
plt.tight_layout() |
|
st.pyplot(fig) |
|
|
|
|
|
@chain |
|
def prompt(inputs): |
|
prompt = [ |
|
SystemMessage(content="""You are an AI assistant whose job is to inspect an image and provide the desired information from the image. If the desired field is not clear or not well detected, return None for this field. Do not try to guess."""), |
|
HumanMessage( |
|
content=[{"type": "text", "text": "Examine the main vehicle type, license plate number, make, model and color."}, |
|
{"type": "text", "text": instructions}, |
|
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{inputs['image']}", "detail": "low"}}] |
|
) |
|
] |
|
return prompt |
|
|
|
|
|
@chain |
|
def MLLM_response(inputs): |
|
model: ChatOpenAI = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0.0, max_tokens=1024) |
|
output = model.invoke(inputs) |
|
return output.content |
|
|
|
|
|
pipeline = image_encoding | prompt | MLLM_response | parser |
|
|
|
|
|
st.header("Upload a Vehicle Image for Information Extraction") |
|
|
|
uploaded_image = st.file_uploader("Choose a JPEG image", type="jpeg") |
|
|
|
if uploaded_image is not None: |
|
|
|
image = PILImage.open(uploaded_image) |
|
st.image(image, caption="Uploaded Image", use_column_width=True) |
|
|
|
|
|
image_path = "/tmp/uploaded_image.jpeg" |
|
with open(image_path, "wb") as f: |
|
f.write(uploaded_image.getbuffer()) |
|
|
|
|
|
output = pipeline.invoke({"image_path": image_path}) |
|
|
|
|
|
st.subheader("Extracted Vehicle Information") |
|
st.json(output) |
|
|
|
|
|
img_dir = "/content/images" |
|
image_paths = glob.glob(os.path.join(img_dir, "*.jpeg")) |
|
display_image_grid(image_paths) |
|
|
|
|
|
st.sidebar.header("Batch Image Upload") |
|
|
|
batch_images = st.sidebar.file_uploader("Upload Images", type="jpeg", accept_multiple_files=True) |
|
|
|
if batch_images: |
|
batch_input = [{"image_path": f"/tmp/{file.name}"} for file in batch_images] |
|
for file in batch_images: |
|
with open(f"/tmp/{file.name}", "wb") as f: |
|
f.write(file.getbuffer()) |
|
|
|
|
|
batch_output = pipeline.batch(batch_input) |
|
df = pd.DataFrame(batch_output) |
|
st.dataframe(df) |
|
|
|
|
|
image_paths = [f"/tmp/{file.name}" for file in batch_images] |
|
display_image_grid(image_paths) |
|
|