vumichien committed
Commit 11e3570 · verified · 1 Parent(s): d8c4ccb

Create main.py

Files changed (1)
  1. main.py +221 -0
main.py ADDED
@@ -0,0 +1,221 @@
+ from ultralytics import YOLO
+ import supervision as sv
+ import cv2
+ import gradio as gr
+ import os
+ import numpy as np
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ import torch
+ import requests
+ from PIL import Image
+ import glob
+ import pandas as pd
+ import time
+ import subprocess
+ 
+ # Install flash-attn at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD avoids compiling the CUDA extension during pip install.
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+ 
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True).to(device).eval()
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
+ onnx_model = YOLO("models/best.onnx", task='detect')
+ 
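+ # Pipeline: the YOLO ONNX model detects objects in the uploaded drawing, crops whose class
+ # is 'mark' are saved to output/ at 4x scale, and Florence-2 then OCRs each crop so that
+ # analysis() can tally how many times each mark appears.
+ 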
+ def ends_with_number(s):
+     # Guard against empty OCR output before checking the last character.
+     return bool(s) and s[-1].isdigit()
+ 
+ 
+ def ocr(image, prompt="<OCR>"):
+     original_height, original_width = image.shape[:2]
+     # OpenCV loads images as BGR; convert to RGB before handing them to the Florence-2 processor.
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     inputs = processor(text=prompt, images=image_rgb, return_tensors="pt").to(device)
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=1024,
+         early_stopping=False,
+         do_sample=False,
+         num_beams=3
+     )
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+ 
+     parsed_answer = processor.post_process_generation(
+         generated_text,
+         task=prompt,
+         image_size=(original_width, original_height)
+     )
+ 
+     return parsed_answer
+ 
+ def parse_detection(detections):
+     parsed_rows = []
+     for i in range(len(detections.xyxy)):
+         x_min = float(detections.xyxy[i][0])
+         y_min = float(detections.xyxy[i][1])
+         x_max = float(detections.xyxy[i][2])
+         y_max = float(detections.xyxy[i][3])
+ 
+         width = int(x_max - x_min)
+         height = int(y_max - y_min)
+ 
+         row = {
+             "top": int(y_min),
+             "left": int(x_min),
+             "width": width,
+             "height": height,
+             "class_id": ""
+             if detections.class_id is None
+             else int(detections.class_id[i]),
+             "confidence": ""
+             if detections.confidence is None
+             else float(detections.confidence[i]),
+             "tracker_id": ""
+             if detections.tracker_id is None
+             else int(detections.tracker_id[i]),
+         }
+ 
+         if hasattr(detections, "data"):
+             for key, value in detections.data.items():
+                 row[key] = (
+                     str(value[i])
+                     if hasattr(value, "__getitem__") and value.ndim != 0
+                     else str(value)
+                 )
+         parsed_rows.append(row)
+     return parsed_rows
+ 
+ 
+ def cut_and_save_image(image, parsed_detections, output_dir):
+     output_path_list = []
+ 
+     for i, det in enumerate(parsed_detections):
+         # Only crop detections whose class is 'mark'
+         if det['class_name'] == 'mark':
+             top = det['top']
+             left = det['left']
+             width = det['width']
+             height = det['height']
+ 
+             # Crop the detection from the full image
+             cut_image = image[top:top + height, left:left + width]
+             # Upscale the crop 4x so small marks are easier to OCR, then save it as a PNG
+             output_path = f"{output_dir}/cut_image_{i}.png"
+             scaled_image = sv.scale_image(image=cut_image, scale_factor=4)
+             cv2.imwrite(output_path, scaled_image)
+             output_path_list.append(output_path)
+     return output_path_list
+ 
+ def analysis(progress=gr.Progress()):
+     """OCR every cropped mark in output/ and tally how often each mark text appears."""
+     progress(0, desc="Analyzing...")
+     list_files = glob.glob("output/*.png")
+     prompt = "<OCR>"
+     results = {}
+     for filepath in progress.tqdm(list_files):
+         basename = os.path.basename(filepath)
+ 
+         image = cv2.imread(filepath)
+ 
+         start_time = time.time()
+         parsed_answer = ocr(image, prompt)
+ 
+         # Mark labels are expected to end with a digit; append "1" when OCR drops it.
+         if not ends_with_number(parsed_answer[prompt]):
+             parsed_answer[prompt] += "1"
+         results[parsed_answer[prompt]] = results.get(parsed_answer[prompt], 0) + 1
+         print(basename, parsed_answer[prompt])
+         print("Time taken:", time.time() - start_time)
+     return pd.DataFrame(list(results.items()), columns=['Mark', 'Total']).reset_index(drop=False).rename(columns={'index': 'No.'})
+ 
+ def inference(
+     image_path,
+     conf_threshold,
+     iou_threshold,
+ ):
+     """
+     YOLOv8 inference function
+     Args:
+         image_path: Path to the image
+         conf_threshold: Confidence threshold
+         iou_threshold: IoU threshold
+     Returns:
+         Annotated image and the list of cropped 'mark' image paths
+     """
+     image = cv2.imread(image_path)
+     original_height, original_width = image.shape[:2]
+     print(image.shape)
+ 
+     results = onnx_model(image, conf=conf_threshold, iou=iou_threshold)[0]
+     detections = sv.Detections.from_ultralytics(results)
+     parsed_detections = parse_detection(detections)
+     output_dir = "output"
+     # Create the output directory if needed; otherwise clear files left from a previous run
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+     else:
+         for f in os.listdir(output_dir):
+             os.remove(os.path.join(output_dir, f))
+ 
+     output_path_list = cut_and_save_image(image, parsed_detections, output_dir)
+ 
+     # Draw boxes and class labels on a copy of the original image
+     box_annotator = sv.BoxAnnotator()
+     label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
+     annotated_image = image.copy()
+     annotated_image = box_annotator.annotate(
+         scene=annotated_image,
+         detections=detections
+     )
+     annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
+     # Convert BGR (OpenCV) to RGB so Gradio displays the colors correctly
+     return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB), output_path_list
+ 
+ 
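+ # Example: the detection entry point can also be called directly (outside the Gradio UI),
+ # e.g. with one of the images listed in EXAMPLES below:
+ #     annotated, crops = inference("examples/train1.png", conf_threshold=0.6, iou_threshold=0.25)
+ 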
+ TITLE = "<h1 style='font-size: 2.5em; text-align: center;'>Identify objects in construction design</h1>"
+ DESCRIPTION = """<p style='font-size: 1.5em; line-height: 1.6em; text-align: left;'>Welcome to the object
+ identification application. This tool allows you to upload an image, and it will identify and annotate objects within
+ the image. Additionally, you can perform OCR analysis on the detected objects.</p> """
+ CSS = """
+ #output {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ h1 {
+     text-align: center;
+ }
+ """
+ EXAMPLES = [
+     ['examples/train1.png', 0.6, 0.25],
+     ['examples/train2.png', 0.9, 0.25],
+     ['examples/train3.png', 0.6, 0.25]
+ ]
+ 
+ 
+ with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
+     gr.HTML(TITLE)
+     gr.HTML(DESCRIPTION)
+     with gr.Tab(label="Identify objects"):
+         with gr.Row():
+             input_img = gr.Image(type="filepath", label="Upload Image")
+             output_img = gr.Image(type="filepath", label="Output Image")
+         with gr.Row():
+             with gr.Column():
+                 conf_thres = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold")
+             with gr.Column():
+                 iou = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold")
+         with gr.Row():
+             with gr.Column():
+                 submit_btn = gr.Button(value="Predict")
+             with gr.Column():
+                 analysis_btn = gr.Button(value="Analysis")
+         with gr.Row():
+             output_df = gr.Dataframe(label="Results")
+         with gr.Row():
+             with gr.Accordion("Gallery", open=False):
+                 gallery = gr.Gallery(label="Detected Mark Object", columns=3)
+     submit_btn.click(inference, [input_img, conf_thres, iou], [output_img, gallery])
+     analysis_btn.click(analysis, [], [output_df])
+     examples = gr.Examples(
+         EXAMPLES,
+         fn=inference,
+         inputs=[input_img, conf_thres, iou],
+         outputs=[output_img, gallery],
+         cache_examples=False,
+     )
+ 
+ demo.launch(debug=True)