Kelechi Osuji committed on
Commit
9f75794
·
2 Parent(s): aade841 16e22fb

fetched and merged changes made in the remote repo

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +16 -34
  3. detr_fine_tuning_custom_dataset.ipynb +3 -0
  4. model.py +12 -7
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  models/*.safetensors filter=lfs diff=lfs merge=lfs -text
37
  corrected[[:space:]]model[[:space:]]path filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  models/*.safetensors filter=lfs diff=lfs merge=lfs -text
37
  corrected[[:space:]]model[[:space:]]path filter=lfs diff=lfs merge=lfs -text
38
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -5,28 +5,28 @@ import matplotlib.pyplot as plt
5
  import io
6
  from model import load_model, get_val_transform # Import functions from model.py
7
  import numpy as np
 
 
 
 
 
 
8
 
9
- # Load model and image processor
10
- model, image_processor = load_model()
11
  val_transform = get_val_transform()
12
 
13
  # Define colors for bounding boxes
14
  COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
15
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
16
 
17
- def preprocess(image):
18
- """Preprocess image using the validation transform."""
19
- numpy_image = np.array(image)
20
- transformed = val_transform(image=numpy_image, category=[]) # No bounding boxes needed for input
21
- return transformed["image"]
22
-
23
- def postprocess(image, outputs, threshold):
24
- """Postprocess outputs to draw bounding boxes on the image."""
25
  plt.figure(figsize=(12, 8))
26
- plt.imshow(image)
27
  ax = plt.gca()
28
 
29
- for box, label, score in zip(outputs['boxes'], outputs['labels'], outputs['scores']):
 
 
 
30
  if score > threshold:
31
  color = COLORS[hash(label) % len(COLORS)]
32
  ax.add_patch(
@@ -40,10 +40,8 @@ def postprocess(image, outputs, threshold):
40
  box[0], box[1] - 5, text, fontsize=10,
41
  bbox=dict(facecolor='yellow', alpha=0.5, edgecolor='none')
42
  )
43
-
44
  plt.axis('off')
45
 
46
- # Convert matplotlib figure to PIL image
47
  buf = io.BytesIO()
48
  plt.savefig(buf, bbox_inches='tight', dpi=100)
49
  buf.seek(0)
@@ -51,25 +49,9 @@ def postprocess(image, outputs, threshold):
51
  return Image.open(buf)
52
 
53
  def detect(image, threshold=0.5):
54
- """Run the detection pipeline."""
55
- # Preprocess the image
56
- processed_image = preprocess(image)
57
-
58
- # Convert to tensor for the model
59
- inputs = image_processor(images=processed_image, return_tensors="pt")
60
-
61
- # Run the model
62
- outputs = model(**inputs)
63
-
64
- # Convert the outputs to a more usable format
65
- results = {
66
- "boxes": outputs.logits.argmax(dim=-1).tolist(), # Replace with actual box extraction logic
67
- "labels": outputs.logits.argmax(dim=-1).tolist(), # Replace with actual label extraction logic
68
- "scores": outputs.scores.tolist(), # Replace with actual score extraction logic
69
- }
70
-
71
- # Postprocess and return the annotated image
72
- return postprocess(image, results, threshold)
73
 
74
  # Build the Gradio app
75
  with gr.Blocks() as demo:
@@ -84,7 +66,7 @@ with gr.Blocks() as demo:
84
  with gr.Row():
85
  image_input = gr.Image(label="Input Image", type="pil")
86
  threshold_slider = gr.Slider(
87
- minimum=0.0, maximum=1.0, step=0.05, value=0.7, label="Detection Threshold"
88
  )
89
 
90
  output_image = gr.Image(label="Output Prediction", type="pil")
 
5
  import io
6
  from model import load_model, get_val_transform # Import functions from model.py
7
  import numpy as np
8
+ <<<<<<< HEAD
9
+ =======
10
+
11
+ # Load the model on GPU if available
12
+ model = load_model(device=0 if torch.cuda.is_available() else -1)
13
+ >>>>>>> 16e22fbe27ebb36b6090c462c63a4d127310b2b8
14
 
 
 
15
  val_transform = get_val_transform()
16
 
17
  # Define colors for bounding boxes
18
  COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
19
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
20
 
21
+ def get_output_figure(pil_img, results, threshold):
 
 
 
 
 
 
 
22
  plt.figure(figsize=(12, 8))
23
+ plt.imshow(pil_img)
24
  ax = plt.gca()
25
 
26
+ for result in results:
27
+ score = result['score']
28
+ label = result['label']
29
+ box = list(result['box'].values())
30
  if score > threshold:
31
  color = COLORS[hash(label) % len(COLORS)]
32
  ax.add_patch(
 
40
  box[0], box[1] - 5, text, fontsize=10,
41
  bbox=dict(facecolor='yellow', alpha=0.5, edgecolor='none')
42
  )
 
43
  plt.axis('off')
44
 
 
45
  buf = io.BytesIO()
46
  plt.savefig(buf, bbox_inches='tight', dpi=100)
47
  buf.seek(0)
 
49
  return Image.open(buf)
50
 
51
  def detect(image, threshold=0.5):
52
+ results = model(image)
53
+ output_image = get_output_figure(image, results, threshold)
54
+ return output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  # Build the Gradio app
57
  with gr.Blocks() as demo:
 
66
  with gr.Row():
67
  image_input = gr.Image(label="Input Image", type="pil")
68
  threshold_slider = gr.Slider(
69
+ minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Detection Threshold"
70
  )
71
 
72
  output_image = gr.Image(label="Output Prediction", type="pil")
detr_fine_tuning_custom_dataset.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b92462743cc8ab359bd1a790f7213b6494fb0726cbf70c25982060ff75e7b06e
3
+ size 7982824
model.py CHANGED
@@ -1,5 +1,6 @@
1
  import albumentations as A
2
  from transformers import AutoModelForObjectDetection, AutoImageProcessor
 
3
 
4
 
5
  # Mapping for labels and IDs
@@ -52,11 +53,15 @@ def get_val_transform():
52
  )
53
 
54
  # Load the model
55
- def load_model():
56
- model_path = "model" # Path to your saved model
57
- image_processor = AutoImageProcessor.from_pretrained(model_path)
58
- model = AutoModelForObjectDetection.from_pretrained(
59
- model_path, # This will automatically use model.safetensors
60
- ignore_mismatched_sizes=True
 
 
 
 
61
  )
62
- return model, image_processor
 
1
  import albumentations as A
2
  from transformers import AutoModelForObjectDetection, AutoImageProcessor
3
+ from transformers import pipeline
4
 
5
 
6
  # Mapping for labels and IDs
 
53
  )
54
 
55
  # Load the model
56
+ def load_model(device: int = -1):
57
+ """
58
+ Load the DETR model pipeline.
59
+ :param device: Specify device to load the model (-1 for CPU, 0 for GPU).
60
+ :return: Hugging Face object-detection pipeline.
61
+ """
62
+ model_pipeline = pipeline(
63
+ "object-detection",
64
+ model="sergiopaniego/detr-resnet-50-dc5-fashionpedia-finetuned",
65
+ device=device
66
  )
67
+ return model_pipeline