fetched and merged changes made in the remote repo
- .gitattributes +1 -0
- app.py +16 -34
- detr_fine_tuning_custom_dataset.ipynb +3 -0
- model.py +12 -7
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 models/*.safetensors filter=lfs diff=lfs merge=lfs -text
 corrected[[:space:]]model[[:space:]]path filter=lfs diff=lfs merge=lfs -text
+*.ipynb filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -5,28 +5,28 @@ import matplotlib.pyplot as plt
 import io
 from model import load_model, get_val_transform  # Import functions from model.py
 import numpy as np
+<<<<<<< HEAD
+=======
+
+# Load the model on GPU if available
+model = load_model(device=0 if torch.cuda.is_available() else -1)
+>>>>>>> 16e22fbe27ebb36b6090c462c63a4d127310b2b8
 
-# Load model and image processor
-model, image_processor = load_model()
 val_transform = get_val_transform()
 
 # Define colors for bounding boxes
 COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
           [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
 
-def
-    """Preprocess image using the validation transform."""
-    numpy_image = np.array(image)
-    transformed = val_transform(image=numpy_image, category=[])  # No bounding boxes needed for input
-    return transformed["image"]
-
-def postprocess(image, outputs, threshold):
-    """Postprocess outputs to draw bounding boxes on the image."""
+def get_output_figure(pil_img, results, threshold):
     plt.figure(figsize=(12, 8))
-    plt.imshow(
+    plt.imshow(pil_img)
     ax = plt.gca()
 
-    for
+    for result in results:
+        score = result['score']
+        label = result['label']
+        box = list(result['box'].values())
         if score > threshold:
             color = COLORS[hash(label) % len(COLORS)]
             ax.add_patch(
@@ -40,10 +40,8 @@ def postprocess(image, outputs, threshold):
                 box[0], box[1] - 5, text, fontsize=10,
                 bbox=dict(facecolor='yellow', alpha=0.5, edgecolor='none')
             )
-
     plt.axis('off')
 
-    # Convert matplotlib figure to PIL image
     buf = io.BytesIO()
     plt.savefig(buf, bbox_inches='tight', dpi=100)
     buf.seek(0)
@@ -51,25 +49,9 @@ def postprocess(image, outputs, threshold):
     return Image.open(buf)
 
 def detect(image, threshold=0.5):
-
-
-
-
-    # Convert to tensor for the model
-    inputs = image_processor(images=processed_image, return_tensors="pt")
-
-    # Run the model
-    outputs = model(**inputs)
-
-    # Convert the outputs to a more usable format
-    results = {
-        "boxes": outputs.logits.argmax(dim=-1).tolist(),  # Replace with actual box extraction logic
-        "labels": outputs.logits.argmax(dim=-1).tolist(),  # Replace with actual label extraction logic
-        "scores": outputs.scores.tolist(),  # Replace with actual score extraction logic
-    }
-
-    # Postprocess and return the annotated image
-    return postprocess(image, results, threshold)
+    results = model(image)
+    output_image = get_output_figure(image, results, threshold)
+    return output_image
 
 # Build the Gradio app
 with gr.Blocks() as demo:
@@ -84,7 +66,7 @@ with gr.Blocks() as demo:
     with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        threshold_slider = gr.Slider(
-            minimum=0.0, maximum=1.0, step=0.05, value=0.
+            minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Detection Threshold"
        )
 
    output_image = gr.Image(label="Output Prediction", type="pil")
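Note that the new app.py still contains the merge conflict markers (<<<<<<< HEAD, =======, >>>>>>>); they are committed verbatim and will raise a SyntaxError the moment the Space imports the module. The visible hunks also never add an import for torch, yet the merged code calls torch.cuda.is_available(). A minimal sketch of a resolved module top, assuming the remote branch's device-aware loading is the side to keep (the exact import list above line 5 is not shown in the diff, so it is partly inferred):

# Sketch of a resolved app.py header (assumption: keep the remote branch's
# device-aware loading and delete the conflict marker lines).
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch  # needed by the CUDA check below; never added by the diff itself

from model import load_model, get_val_transform

# Load the model on GPU if available, otherwise fall back to CPU
model = load_model(device=0 if torch.cuda.is_available() else -1)
val_transform = get_val_transform()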
detr_fine_tuning_custom_dataset.ipynb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b92462743cc8ab359bd1a790f7213b6494fb0726cbf70c25982060ff75e7b06e
+size 7982824
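Because the new *.ipynb rule in .gitattributes routes notebooks through Git LFS (it is the line that git lfs track "*.ipynb" would append), the notebook above is stored as this three-line pointer rather than the ~8 MB file itself. As a quick illustration (not part of this commit; the helper name is made up), a checkout can be probed for unpulled pointer files like so:

# Illustrative helper (not part of this commit): report whether a file in the
# working tree is still a Git LFS pointer rather than the real content.
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    # LFS pointer files begin with the spec line shown in the diff above.
    head = Path(path).read_text(errors="ignore").splitlines()[:1]
    return bool(head) and head[0].startswith("version https://git-lfs.github.com/spec/")

print(is_lfs_pointer("detr_fine_tuning_custom_dataset.ipynb"))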
model.py
CHANGED
@@ -1,5 +1,6 @@
 import albumentations as A
 from transformers import AutoModelForObjectDetection, AutoImageProcessor
+from transformers import pipeline
 
 
 # Mapping for labels and IDs
@@ -52,11 +53,15 @@ def get_val_transform():
     )
 
 # Load the model
-def load_model():
-
-
-    model
-
-
+def load_model(device: int = -1):
+    """
+    Load the DETR model pipeline.
+    :param device: Specify device to load the model (-1 for CPU, 0 for GPU).
+    :return: Hugging Face object-detection pipeline.
+    """
+    model_pipeline = pipeline(
+        "object-detection",
+        model="sergiopaniego/detr-resnet-50-dc5-fashionpedia-finetuned",
+        device=device
     )
-    return
+    return model_pipeline
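The reworked load_model returns a Hugging Face object-detection pipeline, and its output format is exactly what the new get_output_figure in app.py iterates over: a list of dicts with score, label, and box keys, where box maps xmin/ymin/xmax/ymax to pixel coordinates. A minimal usage sketch (the image path is a placeholder, and the first call downloads the checkpoint from the Hub):

from PIL import Image

from model import load_model

detector = load_model(device=-1)  # -1 = CPU; pass 0 for the first GPU
results = detector(Image.open("example.jpg"))  # placeholder path; URLs also work

# Each result matches what get_output_figure expects, e.g.:
# {'score': 0.98, 'label': 'shoe', 'box': {'xmin': 12, 'ymin': 34, 'xmax': 56, 'ymax': 78}}
for r in results:
    print(f"{r['label']}: {r['score']:.2f} at {list(r['box'].values())}")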