File size: 6,334 Bytes
8f9de0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import base64
import glob
import json
import os
import tempfile
from io import BytesIO
from typing import Optional

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import chain
from langchain_openai import ChatOpenAI
from PIL import Image as PILImage

# Page title shown at the top of the Streamlit app
st.title("Vehicle Information Extraction from Images")

# Ask the user for an OpenAI API key; type="password" masks the typed value
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password")

# Export the key as an environment variable once a non-empty value is entered.
# NOTE(review): os.environ is process-wide, so if several users share one app
# process the most recently entered key is used for everyone -- confirm this
# is acceptable for the deployment.
if openai_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key

# Schema of the vehicle attributes the model is asked to extract.
# Documented with comments instead of a class docstring on purpose: a
# docstring would be copied into the generated JSON schema and therefore into
# the format instructions sent to the model.
class Vehicle(BaseModel):
    # Required attributes: the model must always provide these.
    Type: str = Field(..., examples=["Car", "Truck", "Motorcycle", 'Bus', 'Van'], description="The type of the vehicle.")
    License: str = Field(..., description="The license plate number of the vehicle.")
    Make: str = Field(..., examples=["Toyota", "Honda", "Ford", "Suzuki"], description="The Make of the vehicle.")
    Model: str = Field(..., examples=["Corolla", "Civic", "F-150"], description="The Model of the vehicle.")
    # Uses `examples` like the sibling fields (was `example`, which put the
    # hint under a different schema key than every other field).
    Color: str = Field(..., examples=["Red", "Blue", "Black", "White"], description="Return the color of the vehicle.")
    # Optional attributes: default to None when the detail is not visible, so
    # they are annotated Optional[str] to match that default.
    Year: Optional[str] = Field(None, description="The year of the vehicle.")
    Condition: Optional[str] = Field(None, description="The condition of the vehicle.")
    Logo: Optional[str] = Field(None, description="The visible logo of the vehicle, if applicable.")
    Damage: Optional[str] = Field(None, description="Any visible damage or wear and tear on the vehicle.")
    Region: Optional[str] = Field(None, description="Region or country based on the license plate or clues from the image.")
    PlateType: Optional[str] = Field(None, description="Type of license plate, e.g., government, personal.")

# JSON output parser configured with Vehicle as its schema; it is the last
# step of the pipeline and decodes the model's reply.
parser = JsonOutputParser(pydantic_object=Vehicle)
# Format instructions produced by the parser; the prompt step sends this text
# to the model as part of the human message.
instructions = parser.get_format_instructions()

# First pipeline step: turn an image file on disk into base64 text
def image_encoding(inputs):
    """Read the file at ``inputs["image_path"]`` and return it base64-encoded.

    The result is ``{"image": <base64 text>}``, where the text is the UTF-8
    decoded base64 form of the file's raw bytes.
    """
    source_path = inputs["image_path"]
    with open(source_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded_text = base64.b64encode(raw_bytes).decode("utf-8")
    return {"image": encoded_text}

# Image display in grid (for multiple images)
def display_image_grid(image_paths, rows=2, cols=3, figsize=(10, 7)):
    """Render up to ``rows * cols`` images as a titled grid in the Streamlit app.

    Args:
        image_paths: Sequence of image file paths. Paths beyond the grid
            capacity (``rows * cols``) are ignored.
        rows: Number of grid rows.
        cols: Number of grid columns.
        figsize: Matplotlib figure size, passed to ``plt.figure``.

    Returns:
        None. Nothing is drawn when there are no images to show.
    """
    shown_paths = image_paths[: rows * cols]
    if not shown_paths:
        # Avoid rendering an empty figure when no image is available.
        return

    fig = plt.figure(figsize=figsize)
    try:
        for idx, path in enumerate(shown_paths):
            ax = fig.add_subplot(rows, cols, idx + 1)
            ax.imshow(mpimg.imread(path))
            ax.axis('off')
            # basename() handles the platform's path separator, unlike split('/').
            ax.set_title(os.path.basename(path))

        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # Figures created through pyplot stay registered until closed; the
        # script reruns on every interaction, so close it to avoid a leak.
        plt.close(fig)

# Second pipeline step: build the chat messages sent to the model
@chain
def prompt(inputs):
    """Assemble the system and human messages for one vehicle image.

    ``inputs["image"]`` is base64 image text; it is embedded in a
    ``data:image/jpeg;base64,...`` URL requested at "low" detail. The
    module-level ``instructions`` string is included as a second text part.
    Returns a two-element list: the system message, then the human message.
    """
    system_text = """You are an AI assistant tasked with extracting detailed information from a vehicle image. Please extract the following details:
        - Vehicle type (e.g., Car, Truck, Bus)
        - License plate number and type (if identifiable, such as personal, commercial, government)
        - Vehicle make, model, and year (e.g., 2020 Toyota Corolla)
        - Vehicle color and condition (e.g., Red, well-maintained, damaged)
        - Any visible brand logos or distinguishing marks (e.g., Tesla logo)
        - Details of any visible damage (e.g., scratches, dents)
        - Vehicle’s region or country (based on the license plate or other clues)
        If some details are unclear or not visible, return None for those fields. Do not guess or provide inaccurate information."""
    task_text = "Analyze the vehicle in the image and extract as many details as possible, including type, license plate, make, model, year, condition, damage, etc."
    data_url = f"data:image/jpeg;base64,{inputs['image']}"

    # Order of the parts: task description, output-format instructions, image.
    human_parts = [
        {"type": "text", "text": task_text},
        {"type": "text", "text": instructions},
        {"type": "image_url", "image_url": {"url": data_url, "detail": "low"}},
    ]
    return [
        SystemMessage(content=system_text),
        HumanMessage(content=human_parts),
    ]

# Third pipeline step: query the multimodal model
@chain
def MLLM_response(inputs):
    """Send the prepared messages to the chat model and return its reply text.

    A ``ChatOpenAI`` client for "gpt-4o-2024-08-06" is created on each call
    with temperature 0.0 and a 1024-token output limit; only the ``content``
    of the model's response is returned.
    """
    llm = ChatOpenAI(
        model="gpt-4o-2024-08-06",
        temperature=0.0,
        max_tokens=1024,
    )
    reply = llm.invoke(inputs)
    return reply.content

# The complete pipeline for extracting vehicle details. Steps run left to
# right: {"image_path": ...} -> base64 image -> chat messages -> model reply
# text -> parsed JSON.
pipeline = image_encoding | prompt | MLLM_response | parser

# Streamlit interface for uploading one image and showing the extracted details
st.header("Upload a Vehicle Image for Information Extraction")

# Accept both common JPEG extensions; type="jpeg" alone rejects ".jpg" files.
uploaded_image = st.file_uploader("Choose a JPEG image", type=["jpg", "jpeg"])

if uploaded_image is not None:
    # Display the uploaded image
    image = PILImage.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not os.environ.get("OPENAI_API_KEY"):
        # Without a key the model call would fail with a raw traceback.
        st.warning("Enter your OpenAI API key above to analyze the image.")
    else:
        # The pipeline reads the image from disk, so spill the upload to a
        # uniquely named temp file. A fixed path such as /tmp/uploaded_image.jpeg
        # would be shared (and overwritten) by every concurrent session.
        with tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False) as tmp_file:
            tmp_file.write(uploaded_image.getbuffer())
            image_path = tmp_file.name

        try:
            # Process the image through the pipeline
            output = pipeline.invoke({"image_path": image_path})
        except Exception as exc:  # UI boundary: report the failure instead of crashing
            st.error(f"Could not extract vehicle information: {exc}")
        else:
            # Show the results in a user-friendly format
            st.subheader("Extracted Vehicle Information")
            st.json(output)
        finally:
            # The temp file is only needed for the duration of the call.
            os.remove(image_path)

    # Optionally, display more vehicle images from the folder. The folder may
    # not exist on this machine, in which case glob returns an empty list and
    # the gallery is skipped.
    img_dir = "/content/images"
    image_paths = glob.glob(os.path.join(img_dir, "*.jpeg"))
    if image_paths:
        display_image_grid(image_paths)

# Sidebar uploader for processing a batch of images in one go
st.sidebar.header("Batch Image Upload")

# Accept both common JPEG extensions; type="jpeg" alone rejects ".jpg" files.
batch_images = st.sidebar.file_uploader(
    "Upload Images", type=["jpg", "jpeg"], accept_multiple_files=True
)

if batch_images:
    if not os.environ.get("OPENAI_API_KEY"):
        # Without a key the model calls would fail with a raw traceback.
        st.warning("Enter your OpenAI API key above to analyze the images.")
    else:
        # A private temp directory keeps sessions from overwriting each other's
        # files and is removed automatically when the block exits.
        with tempfile.TemporaryDirectory() as batch_dir:
            image_paths = []
            for idx, file in enumerate(batch_images):
                # file.name is supplied by the client: drop any directory
                # components so the write cannot escape the temp directory.
                safe_name = os.path.basename(file.name)
                if safe_name in ("", ".", ".."):
                    safe_name = f"image_{idx}.jpeg"
                # One sub-directory per upload so two files with the same name
                # do not overwrite each other while keeping their file names.
                slot_dir = os.path.join(batch_dir, str(idx))
                os.makedirs(slot_dir)
                saved_path = os.path.join(slot_dir, safe_name)
                with open(saved_path, "wb") as f:
                    f.write(file.getbuffer())
                image_paths.append(saved_path)

            batch_input = [{"image_path": path} for path in image_paths]

            try:
                # Process the batch and display the results in a DataFrame
                batch_output = pipeline.batch(batch_input)
            except Exception as exc:  # UI boundary: report the failure instead of crashing
                st.error(f"Could not extract vehicle information: {exc}")
            else:
                df = pd.DataFrame(batch_output)
                st.dataframe(df)

            # Show images in a grid (must happen before the temp dir is removed)
            display_image_grid(image_paths)