boostedhug commited on
Commit
b1c89e5
·
verified ·
1 Parent(s): 5b8e275

fixed recent traceback

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
- from PIL import Image
4
  import torch
5
 
6
  # Load the DETR model and processor
@@ -11,7 +11,7 @@ model = DetrForObjectDetection.from_pretrained(model_name)
11
  def detect_accident(image):
12
  """Process an image and detect traffic accidents using the DETR model."""
13
  # Preprocess the input image
14
- inputs = processor(images=image, return_tensors="pt")
15
 
16
  # Get model predictions
17
  outputs = model(**inputs)
@@ -35,8 +35,8 @@ def detect_accident(image):
35
  # Define the Gradio interface
36
  iface = gr.Interface(
37
  fn=detect_accident,
38
- inputs=gr.inputs.Image(type="pil"),
39
- outputs=gr.outputs.Image(type="pil"),
40
  title="Traffic Accident Detection",
41
  description="Upload an image to detect traffic accidents using the DETR model."
42
  )
@@ -55,6 +55,7 @@ iface.launch()
55
 
56
 
57
 
 
58
  # from fastapi import FastAPI, File, UploadFile
59
  # from fastapi.responses import JSONResponse
60
  # from fastapi.middleware.cors import CORSMiddleware
 
1
  import gradio as gr
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
+ from PIL import Image, ImageDraw
4
  import torch
5
 
6
  # Load the DETR model and processor
 
11
  def detect_accident(image):
12
  """Process an image and detect traffic accidents using the DETR model."""
13
  # Preprocess the input image
14
+ inputs = processor(images=image, return_tensors="pt", size={"longest_edge": 800})
15
 
16
  # Get model predictions
17
  outputs = model(**inputs)
 
35
  # Define the Gradio interface
36
  iface = gr.Interface(
37
  fn=detect_accident,
38
+ inputs=gr.Image(type="pil"),
39
+ outputs=gr.Image(type="pil"),
40
  title="Traffic Accident Detection",
41
  description="Upload an image to detect traffic accidents using the DETR model."
42
  )
 
55
 
56
 
57
 
58
+
59
  # from fastapi import FastAPI, File, UploadFile
60
  # from fastapi.responses import JSONResponse
61
  # from fastapi.middleware.cors import CORSMiddleware