karthickn commited on
Commit
50526a2
·
1 Parent(s): 16b71de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -4,12 +4,16 @@ from transformers import DetrImageProcessor, DetrForObjectDetection
4
  from color import Color
5
  from color_wheel import ColorWheel
6
  from PIL import ImageDraw, ImageFont
 
7
 
8
- processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
9
- model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
 
 
 
 
 
10
 
11
-
12
- def process_image(image, margin):
13
  if image is None:
14
  yield [None, None, None]
15
  return
@@ -30,6 +34,7 @@ def process_image(image, margin):
30
  index = 0
31
  gallery = []
32
  labels = []
 
33
  drawImage = image.copy()
34
  draw = ImageDraw.Draw(drawImage)
35
  for score, label, box in zip(results['scores'], results['labels'], results['boxes']):
@@ -43,8 +48,14 @@ def process_image(image, margin):
43
  draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline=colors[index], width=4)
44
  gallery.append(image.crop((box[0], box[1], box[2], box[3])))
45
  labels.append(model.config.id2label[label.item()])
 
 
46
  index += 1
47
- yield [drawImage, gallery, ','.join(labels)]
 
 
 
 
48
 
49
  app = gr.Interface(
50
  title='Object Detection for Image',
@@ -52,14 +63,17 @@ app = gr.Interface(
52
  inputs=[
53
  gr.Image(type='pil'),
54
  gr.Slider(maximum=100, step=1, label='margin'),
 
55
  ],
56
  outputs=[
57
  gr.Image(label='boxes', type='pil'),
58
  gr.Gallery(label='gallery', columns=8, height=140),
 
59
  gr.Textbox(label='text'),
60
  ],
61
  allow_flagging='never',
62
  examples=[['examples/Wild.jpg', 0], ['examples/Football-Match.jpg', 0]],
 
63
  #cache_examples=False
64
  )
65
  app.queue(concurrency_count=20)
 
4
  from color import Color
5
  from color_wheel import ColorWheel
6
  from PIL import ImageDraw, ImageFont
7
+ import numpy as np
8
 
9
+ def process_image(image, margin, model):
10
+ if model=='detr-resnet-101':
11
+ processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-101')
12
+ model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101')
13
+ else:
14
+ processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
15
+ model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
16
 
 
 
17
  if image is None:
18
  yield [None, None, None]
19
  return
 
34
  index = 0
35
  gallery = []
36
  labels = []
37
+ newlabel = {}
38
  drawImage = image.copy()
39
  draw = ImageDraw.Draw(drawImage)
40
  for score, label, box in zip(results['scores'], results['labels'], results['boxes']):
 
48
  draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline=colors[index], width=4)
49
  gallery.append(image.crop((box[0], box[1], box[2], box[3])))
50
  labels.append(model.config.id2label[label.item()])
51
+ print(model.config.id2label[label.item()])
52
+ newlabel[model.config.id2label[label.item()]] = 1
53
  index += 1
54
+ print('----------')
55
+ print(labels)
56
+ print(results['scores'])
57
+ print(newlabel)
58
+ yield [drawImage, gallery, newlabel, ','.join(labels)]
59
 
60
  app = gr.Interface(
61
  title='Object Detection for Image',
 
63
  inputs=[
64
  gr.Image(type='pil'),
65
  gr.Slider(maximum=100, step=1, label='margin'),
66
+ gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Select the model")
67
  ],
68
  outputs=[
69
  gr.Image(label='boxes', type='pil'),
70
  gr.Gallery(label='gallery', columns=8, height=140),
71
+ gr.Label(label='scores'),
72
  gr.Textbox(label='text'),
73
  ],
74
  allow_flagging='never',
75
  examples=[['examples/Wild.jpg', 0], ['examples/Football-Match.jpg', 0]],
76
+ # theme='HaleyCH/HaleyCH_Theme',
77
  #cache_examples=False
78
  )
79
  app.queue(concurrency_count=20)