Víctor Sáez commited on
Commit
6ecfb14
·
1 Parent(s): 53e14a8

Update Gradio interface and add arial.ttf tracked via LFS

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. .gitignore +43 -0
  3. app.py +25 -11
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/fonts/arial.ttf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.so
5
+
6
+ # Virtual environment
7
+ venv/
8
+ .env/
9
+ .env.*
10
+
11
+ # PyCharm
12
+ .idea/
13
+ *.iml
14
+
15
+ # Model cache (Hugging Face, PyTorch, etc.)
16
+ ~/.cache/
17
+ .cache/
18
+ *.ckpt
19
+ *.pt
20
+ *.bin
21
+ *.safetensors
22
+
23
+ # System files
24
+ .DS_Store
25
+ Thumbs.db
26
+
27
+ # Logs
28
+ *.log
29
+
30
+ # Jupyter Notebooks (if any outputs get messy)
31
+ .ipynb_checkpoints/
32
+
33
+ # Test files or temp
34
+ *.tmp
35
+ *.bak
36
+
37
+ # Fonts or large assets
38
+ assets/fonts/*.ttf
39
+
40
+ # Optional: don't track test images
41
+ test_images/
42
+ *.jpg
43
+ *.png
app.py CHANGED
@@ -1,15 +1,25 @@
1
  import gradio as gr
 
2
  from PIL import Image, ImageDraw, ImageFont
3
  from transformers import DetrImageProcessor, DetrForObjectDetection
4
- import torch
 
5
 
6
  # Load DETR model and processor from Hugging Face
7
  model_name = "facebook/detr-resnet-50"
8
  processor = DetrImageProcessor.from_pretrained(model_name)
9
  model = DetrForObjectDetection.from_pretrained(model_name)
10
 
11
- # Load default font
12
- font = ImageFont.load_default()
 
 
 
 
 
 
 
 
13
 
14
  # Main function: takes an image and returns it with boxes and labels
15
  def detect_objects(image):
@@ -41,20 +51,24 @@ def detect_objects(image):
41
  # Set background rectangle for text
42
  text_background = [
43
  box[0], box[1] - text_height,
44
- box[0] + text_width, box[1]
45
  ]
46
  draw.rectangle(text_background, fill="black") # Background
47
  draw.text((box[0], box[1] - text_height), label_text, fill="white", font=font)
48
 
49
  return image_with_boxes
50
 
51
- # Gradio interface
52
- app = gr.Interface(
53
- fn=detect_objects,
54
- inputs=gr.Image(type="pil"),
55
- outputs=gr.Image()
56
- )
57
 
58
- # Run app
 
 
 
 
 
 
 
 
 
 
59
  if __name__ == "__main__":
60
  app.launch()
 
1
  import gradio as gr
2
+ import torch
3
  from PIL import Image, ImageDraw, ImageFont
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
+ from pathlib import Path
6
+
7
 
8
  # Load DETR model and processor from Hugging Face
9
  model_name = "facebook/detr-resnet-50"
10
  processor = DetrImageProcessor.from_pretrained(model_name)
11
  model = DetrForObjectDetection.from_pretrained(model_name)
12
 
13
+ # Load font
14
+ font_path = Path("assets/fonts/arial.ttf")
15
+ if not font_path.exists():
16
+ # If the font file does not exist, use the default PIL font
17
+ print(f"Font file {font_path} not found. Using default font.")
18
+ font = ImageFont.load_default()
19
+ else:
20
+ font = ImageFont.truetype(str(font_path), size=100)
21
+
22
+ print(f"CUDA is available: {torch.cuda.is_available()}")
23
 
24
  # Main function: takes an image and returns it with boxes and labels
25
  def detect_objects(image):
 
51
  # Set background rectangle for text
52
  text_background = [
53
  box[0], box[1] - text_height,
54
+ box[0] + text_width, box[1]
55
  ]
56
  draw.rectangle(text_background, fill="black") # Background
57
  draw.text((box[0], box[1] - text_height), label_text, fill="white", font=font)
58
 
59
  return image_with_boxes
60
 
 
 
 
 
 
 
61
 
62
+ with gr.Blocks() as app:
63
+ with gr.Row():
64
+ gr.Markdown("## Object Detection App\nUpload an image to detect objects using Facebook's DETR model.")
65
+ with gr.Row():
66
+ input_image = gr.Image(type="pil", label="Input Image")
67
+ output_image = gr.Image(label="Detected Objects")
68
+ with gr.Row():
69
+ button = gr.Button("Detect Objects")
70
+
71
+ button.click(fn=detect_objects, inputs=input_image, outputs=output_image)
72
+
73
  if __name__ == "__main__":
74
  app.launch()