|
import base64
import glob
import json
import os
import tempfile
from io import BytesIO
from typing import Optional

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import chain
from langchain_openai import ChatOpenAI
from PIL import Image as PILImage
|
|
|
# Streamlit title
st.title("Vehicle Information Extraction from Images")

# Prompt user for OpenAI API key (masked input).
# Keys are usually pasted, and a stray leading/trailing space or newline would
# make an otherwise valid key fail authentication, so surrounding whitespace
# is stripped before the key is used.
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password").strip()

# Export the key through the environment only when one was actually entered;
# the chat model created later is constructed without an explicit key.
if openai_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key
|
|
|
# Schema of the details extracted for a single vehicle.
#
# The field names are the keys the model is asked to return: this class is
# handed to JsonOutputParser, which derives the format instructions from it.
# Renaming a field therefore changes both the prompt and the output keys.
#
# NOTE: documented with comments rather than a class docstring on purpose.
# pydantic copies a model's docstring into its JSON schema as "description",
# which would change the format instructions embedded in the prompt.
class Vehicle(BaseModel):
    # --- Fields the model must always return ---
    Type: str = Field(
        ...,
        examples=["Car", "Truck", "Motorcycle", "Bus", "Van"],
        description="The type of the vehicle.",
    )
    License: str = Field(..., description="The license plate number of the vehicle.")
    Make: str = Field(
        ...,
        examples=["Toyota", "Honda", "Ford", "Suzuki"],
        description="The Make of the vehicle.",
    )
    Model: str = Field(
        ...,
        examples=["Corolla", "Civic", "F-150"],
        description="The Model of the vehicle.",
    )
    # Uses `examples=` like every other field (it was previously spelled
    # `example=`, which produced a differently named key in the schema).
    Color: str = Field(
        ...,
        examples=["Red", "Blue", "Black", "White"],
        description="Return the color of the vehicle.",
    )

    # --- Optional fields: default to None when the detail is not available ---
    Year: Optional[str] = Field(None, description="The year of the vehicle.")
    Condition: Optional[str] = Field(None, description="The condition of the vehicle.")
    Logo: Optional[str] = Field(
        None, description="The visible logo of the vehicle, if applicable."
    )
    Damage: Optional[str] = Field(
        None, description="Any visible damage or wear and tear on the vehicle."
    )
    Region: Optional[str] = Field(
        None,
        description="Region or country based on the license plate or clues from the image.",
    )
    PlateType: Optional[str] = Field(
        None, description="Type of license plate, e.g., government, personal."
    )
|
|
|
# Parser for vehicle details: turns the model's JSON reply into Python data,
# using the Vehicle model as the description of the expected structure.
parser = JsonOutputParser(pydantic_object=Vehicle)

# Text describing the expected output format; it is computed once here and
# inserted into every prompt sent to the model.
instructions = parser.get_format_instructions()
|
|
|
# Image encoding function (for base64 encoding)
def image_encoding(inputs):
    """Read an image file from disk and return its contents base64-encoded.

    Args:
        inputs: Mapping with an ``"image_path"`` entry naming the file to read.

    Returns:
        A dict ``{"image": <base64 text>}`` holding the encoded file contents.

    Raises:
        KeyError: If ``inputs`` has no ``"image_path"`` entry.
        OSError: If the file cannot be opened or read.
    """
    source_path = inputs["image_path"]
    with open(source_path, "rb") as handle:
        raw_bytes = handle.read()

    # b64encode yields bytes; decode so the result can be embedded in a string.
    encoded_text = base64.b64encode(raw_bytes).decode("utf-8")
    return {"image": encoded_text}
|
|
|
# Image display in grid (for multiple images)
def display_image_grid(image_paths, rows=2, cols=3, figsize=(10, 7)):
    """Render images as a titled grid inside the Streamlit app.

    At most ``rows * cols`` images are shown; any further paths are ignored.
    Each cell is titled with the image's file name and drawn without axes.

    Args:
        image_paths: Sequence of paths to image files.
        rows: Number of grid rows.
        cols: Number of grid columns.
        figsize: Figure size passed to matplotlib, as ``(width, height)``.

    Returns:
        None. When there is nothing to show, no figure is created or rendered.
    """
    shown_paths = image_paths[: rows * cols]
    if not shown_paths:
        # Avoid creating and rendering an empty figure.
        return

    fig = plt.figure(figsize=figsize)
    for position, path in enumerate(shown_paths, start=1):
        ax = fig.add_subplot(rows, cols, position)
        ax.imshow(mpimg.imread(path))
        ax.axis("off")
        # basename handles the platform's path separator, unlike split("/").
        ax.set_title(os.path.basename(path))

    # Lay out this specific figure rather than whatever pyplot considers current.
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps a reference to every figure it creates; close this one so
    # figures do not pile up each time the script reruns.
    plt.close(fig)
|
|
|
# Create the prompt for the AI model
@chain
def prompt(inputs):
    """Build the message list sent to the vision model for one image.

    Args:
        inputs: Mapping with an ``"image"`` entry holding the base64-encoded
            image; it is embedded in a ``data:image/jpeg`` URL.

    Returns:
        A two-element list: a system message describing the extraction task,
        and a human message carrying the request text, the parser's format
        instructions and the image itself.
    """
    # NOTE(review): this text was reconstructed from a copy whose indentation
    # and blank lines were lost; lines are joined with single newlines and no
    # leading spaces — confirm against the deployed prompt.
    system_text = (
        "You are an AI assistant tasked with extracting detailed information "
        "from a vehicle image. Please extract the following details:\n"
        "- Vehicle type (e.g., Car, Truck, Bus)\n"
        "- License plate number and type (if identifiable, such as personal, commercial, government)\n"
        "- Vehicle make, model, and year (e.g., 2020 Toyota Corolla)\n"
        "- Vehicle color and condition (e.g., Red, well-maintained, damaged)\n"
        "- Any visible brand logos or distinguishing marks (e.g., Tesla logo)\n"
        "- Details of any visible damage (e.g., scratches, dents)\n"
        "- Vehicle’s region or country (based on the license plate or other clues)\n"
        "If some details are unclear or not visible, return None for those fields. "
        "Do not guess or provide inaccurate information."
    )

    request_part = {
        "type": "text",
        "text": (
            "Analyze the vehicle in the image and extract as many details as possible, "
            "including type, license plate, make, model, year, condition, damage, etc."
        ),
    }
    # Output-format instructions generated by the parser at module level.
    format_part = {"type": "text", "text": instructions}
    # The image travels inline as a data URL, requested at "low" detail.
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": f"data:image/jpeg;base64,{inputs['image']}",
            "detail": "low",
        },
    }

    return [
        SystemMessage(content=system_text),
        HumanMessage(content=[request_part, format_part, image_part]),
    ]
|
|
|
# Invoke the model for extracting vehicle details
@chain
def MLLM_response(inputs):
    """Send the prepared messages to the OpenAI chat model and return its reply.

    A new ``ChatOpenAI`` client is created on every call, configured for
    ``gpt-4o-2024-08-06`` with temperature 0.0 and at most 1024 output tokens.

    Args:
        inputs: The messages to send, as produced by the prompt step.

    Returns:
        The ``content`` of the model's response.
    """
    llm = ChatOpenAI(
        model="gpt-4o-2024-08-06",
        temperature=0.0,
        max_tokens=1024,
    )
    reply = llm.invoke(inputs)
    return reply.content
|
|
|
# The complete pipeline for extracting vehicle details.
# Data flows left to right, each stage feeding the next:
pipeline = (
    image_encoding      # {"image_path": ...} -> {"image": <base64 text>}
    | prompt            # -> system + human messages for the model
    | MLLM_response     # -> text of the model's reply
    | parser            # -> reply parsed as JSON
)
|
|
|
# Streamlit Interface for uploading images and showing results
st.header("Upload a Vehicle Image for Information Extraction")

# Both common JPEG extensions are listed explicitly so ".jpg" files are
# accepted as well as ".jpeg" ones.
uploaded_image = st.file_uploader("Choose a JPEG image", type=["jpg", "jpeg"])

if uploaded_image is not None:
    # Display the uploaded image
    image = PILImage.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not os.environ.get("OPENAI_API_KEY"):
        # Without a key the model call can only fail; say so instead of
        # surfacing a stack trace. The environment is checked (not just the
        # text box) so a key exported before launch keeps working.
        st.warning("Enter your OpenAI API key above to analyze the image.")
    else:
        # The pipeline reads the image from a path, so spill the upload to a
        # uniquely named temporary file. A fixed name would be overwritten by
        # other sessions running at the same time.
        # delete=False: the file is closed first and reopened by the pipeline.
        with tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False) as tmp_file:
            tmp_file.write(uploaded_image.getbuffer())
            image_path = tmp_file.name

        try:
            # Process the image through the pipeline
            output = pipeline.invoke({"image_path": image_path})
        except Exception as exc:  # UI boundary: report the failure to the user
            st.error(f"Vehicle information extraction failed: {exc}")
        else:
            # Show the results in a user-friendly format
            st.subheader("Extracted Vehicle Information")
            st.json(output)
        finally:
            # The temporary copy is only needed for the duration of the call.
            os.remove(image_path)
|
|
|
# Optionally, display more vehicle images from the folder
# NOTE(review): indentation was lost in the copy under review; this section is
# treated as top-level code that runs on every rerun — confirm it was not
# meant to live inside the single-image upload branch.
img_dir = "/content/images"

# glob returns matches in arbitrary order; sort so the same images are shown
# on every rerun.
image_paths = sorted(glob.glob(os.path.join(img_dir, "*.jpeg")))

# The folder may be missing or empty (glob then returns an empty list);
# skip the grid instead of rendering an empty figure.
if image_paths:
    display_image_grid(image_paths)
|
|
|
# You can also allow users to upload and process a batch of images
st.sidebar.header("Batch Image Upload")

batch_images = st.sidebar.file_uploader(
    "Upload Images", type=["jpg", "jpeg"], accept_multiple_files=True
)

if batch_images:
    if not os.environ.get("OPENAI_API_KEY"):
        # Without a key every model call in the batch would fail.
        st.warning("Enter your OpenAI API key above to process the uploaded batch.")
    else:
        # Work inside a private temporary directory that is removed when the
        # block exits, so uploads neither collide with other sessions nor
        # accumulate on disk.
        with tempfile.TemporaryDirectory() as batch_dir:
            image_paths = []
            for index, file in enumerate(batch_images):
                # file.name is supplied by the browser: keep only its final
                # path component so it cannot point outside batch_dir.
                safe_name = os.path.basename(file.name) or f"image_{index}.jpeg"
                # One sub-directory per upload keeps the original file name
                # (used as the grid title) even when two uploads share a name.
                file_dir = os.path.join(batch_dir, str(index))
                os.makedirs(file_dir)
                saved_path = os.path.join(file_dir, safe_name)
                with open(saved_path, "wb") as f:
                    f.write(file.getbuffer())
                image_paths.append(saved_path)

            # Process the batch and display the results in a DataFrame
            batch_input = [{"image_path": path} for path in image_paths]
            try:
                batch_output = pipeline.batch(batch_input)
            except Exception as exc:  # UI boundary: report the failure to the user
                st.error(f"Batch extraction failed: {exc}")
            else:
                df = pd.DataFrame(batch_output)
                st.dataframe(df)

            # Show images in a grid (must happen before the directory is removed)
            display_image_grid(image_paths)