shukdevdatta123 committed on
Commit
4f93965
·
verified ·
1 Parent(s): 38eb7ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import glob
4
+ import base64
5
+ import json
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.image as mpimg
9
+ from langchain_openai import ChatOpenAI
10
+ from langchain_core.pydantic_v1 import BaseModel, Field
11
+ from langchain_core.messages import HumanMessage, SystemMessage
12
+ from langchain_core.output_parsers import JsonOutputParser
13
+ from langchain_core.runnables import chain
14
+ from PIL import Image as PILImage
15
+ from io import BytesIO
16
+
17
# Page heading shown at the top of the app.
st.title("Vehicle Information Extraction from Images")

# Masked input so the key is not echoed on screen.
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password")

# Export the key through the environment, but only once one has been entered.
# NOTE(review): os.environ is process-wide, so in a shared deployment the most
# recently entered key is presumably visible to every session - confirm.
if openai_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key
26
+
27
# Schema of the vehicle attributes to extract from an image.
# Documented with comments rather than a class docstring on purpose: the model
# is handed to JsonOutputParser below, and a docstring would presumably be
# picked up as part of the generated schema and so change the prompt.
class Vehicle(BaseModel):
    # Broad vehicle category.
    Type: str = Field(..., examples=["Car", "Truck", "Motorcycle", "Bus", "Van"], description="The type of the vehicle.")
    # License plate text as read from the image.
    License: str = Field(..., description="The license plate number of the vehicle.")
    # Manufacturer.
    Make: str = Field(..., examples=["Toyota", "Honda", "Ford", "Suzuki"], description="The Make of the vehicle.")
    # Model name.
    Model: str = Field(..., examples=["Corolla", "Civic", "F-150"], description="The Model of the vehicle.")
    # Uses `examples=` like the other fields (was `example=`), so every field
    # exposes its sample values under the same key.
    Color: str = Field(..., examples=["Red", "Blue", "Black", "White"], description="Return the color of the vehicle.")


# Parser bound to the Vehicle schema; its format instructions are inserted
# into the prompt so the model knows which JSON shape to produce.
parser = JsonOutputParser(pydantic_object=Vehicle)
instructions = parser.get_format_instructions()
38
+
39
+ # Image encoding function (for base64 encoding)
40
def image_encoding(inputs):
    """Read the file at ``inputs["image_path"]`` and base64-encode its bytes.

    Returns a dict with a single key, ``"image"``, holding the encoded
    content as a UTF-8 string.
    """
    with open(inputs["image_path"], "rb") as fh:
        raw_bytes = fh.read()
    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    return {"image": encoded}
45
+
46
+ # Image display in grid (for multiple images)
47
def display_image_grid(image_paths, rows=2, cols=3, figsize=(10, 7)):
    """Render up to ``rows * cols`` images as a titled grid on the Streamlit page.

    Args:
        image_paths: Paths of the image files to show; any beyond
            ``rows * cols`` are ignored.
        rows: Number of grid rows.
        cols: Number of grid columns.
        figsize: Figure size passed to ``plt.figure``.

    Returns:
        None. Nothing is drawn when there are no images to show.
    """
    shown_paths = image_paths[: rows * cols]
    if not shown_paths:
        # Avoid pushing an empty figure to the page.
        return

    fig = plt.figure(figsize=figsize)
    try:
        for position, path in enumerate(shown_paths, start=1):
            ax = fig.add_subplot(rows, cols, position)
            ax.imshow(mpimg.imread(path))
            ax.axis('off')
            # basename follows the platform's path separator, unlike split('/').
            ax.set_title(os.path.basename(path))

        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates alive until it is closed, so
        # release it here; the script is re-executed on each interaction.
        plt.close(fig)
62
+
63
+ # Create the prompt for the AI model
64
@chain
def prompt(inputs):
    """Build the message list sent to the vision model.

    ``inputs["image"]`` is the base64-encoded image; it is embedded as a
    JPEG data URL next to the task text and the parser's format instructions.
    Returns a two-element list: the system message, then the human message.
    """
    system_message = SystemMessage(content="""You are an AI assistant whose job is to inspect an image and provide the desired information from the image. If the desired field is not clear or not well detected, return None for this field. Do not try to guess.""")

    task_part = {"type": "text", "text": "Examine the main vehicle type, license plate number, make, model and color."}
    format_part = {"type": "text", "text": instructions}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{inputs['image']}", "detail": "low"},
    }
    human_message = HumanMessage(content=[task_part, format_part, image_part])

    return [system_message, human_message]
75
+
76
+ # Invoke the model for extracting vehicle details
77
@chain
def MLLM_response(inputs):
    """Send the prepared messages to the chat model and return the reply's content."""
    llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0.0, max_tokens=1024)
    reply = llm.invoke(inputs)
    return reply.content
82
+
83
# The complete pipeline: image path -> base64 image -> chat messages ->
# model reply -> parsed JSON.
pipeline = image_encoding | prompt | MLLM_response | parser

# Single-image workflow: upload, preview, extract, show the result.
st.header("Upload a Vehicle Image for Information Extraction")

uploaded_image = st.file_uploader("Choose a JPEG image", type="jpeg")

if uploaded_image is not None:
    # Preview the uploaded image.
    image = PILImage.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not os.environ.get("OPENAI_API_KEY"):
        # Without a key the model call can only fail; tell the user instead
        # of surfacing a traceback.
        st.warning("Please enter your OpenAI API key to extract vehicle information.")
    else:
        # The pipeline reads images from disk, so persist the upload first.
        image_path = "/tmp/uploaded_image.jpeg"
        with open(image_path, "wb") as f:
            f.write(uploaded_image.getbuffer())

        try:
            output = pipeline.invoke({"image_path": image_path})
        except Exception as exc:  # UI boundary: report the failure, keep the app alive
            st.error(f"Could not extract vehicle information: {exc}")
        else:
            st.subheader("Extracted Vehicle Information")
            st.json(output)

    # Optionally show more vehicle images from a local folder; skipped when
    # the folder is missing or holds no JPEG files.
    # NOTE(review): placed inside the upload branch on the assumption that the
    # original did the same - the source's indentation was lost; confirm.
    img_dir = "/content/images"
    image_paths = glob.glob(os.path.join(img_dir, "*.jpeg"))
    if image_paths:
        display_image_grid(image_paths)
112
+
113
# Batch workflow: upload several images in the sidebar, extract details for
# each one, and show the results as a table followed by an image grid.
st.sidebar.header("Batch Image Upload")

batch_images = st.sidebar.file_uploader("Upload Images", type="jpeg", accept_multiple_files=True)

if batch_images:
    # Persist each upload once and remember where it went, so the pipeline
    # input and the image grid share the same paths.
    image_paths = []
    for file in batch_images:
        # file.name comes from the client; keep only its final component so
        # the write cannot land outside /tmp.
        image_path = os.path.join("/tmp", os.path.basename(file.name))
        with open(image_path, "wb") as f:
            f.write(file.getbuffer())
        image_paths.append(image_path)

    if not os.environ.get("OPENAI_API_KEY"):
        # Without a key the model calls can only fail; tell the user instead
        # of surfacing a traceback.
        st.warning("Please enter your OpenAI API key to extract vehicle information.")
    else:
        batch_input = [{"image_path": path} for path in image_paths]
        try:
            batch_output = pipeline.batch(batch_input)
        except Exception as exc:  # UI boundary: report the failure, keep the app alive
            st.error(f"Could not extract vehicle information: {exc}")
        else:
            # One row per image, one column per extracted field.
            df = pd.DataFrame(batch_output)
            st.dataframe(df)

    # Show the uploaded images in a grid.
    display_image_grid(image_paths)