rifatramadhani commited on
Commit
ff22380
·
1 Parent(s): f58205e
age_estimation/age_estimation.py CHANGED
@@ -5,6 +5,8 @@ from utils.image_utils import preprocess_image, get_image_from_input
5
  from utils.face_detector import load_face_detector
6
  from .predict import predict_age
7
 
 
 
8
 
9
  def age_estimation(input_type, uploaded_image, image_url, base64_string):
10
  """
@@ -32,9 +34,10 @@ def age_estimation(input_type, uploaded_image, image_url, base64_string):
32
 
33
  try:
34
  face_detector = load_face_detector()
35
-
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
- model = load_model(device)
 
38
 
39
  # Preprocess the image (convert PIL to numpy, ensure RGB)
40
  processed_image = preprocess_image(image)
 
5
  from utils.face_detector import load_face_detector
6
  from .predict import predict_age
7
 
8
+ model = load_model("cpu") # Load the model on CPU by default
9
+
10
 
11
  def age_estimation(input_type, uploaded_image, image_url, base64_string):
12
  """
 
34
 
35
  try:
36
  face_detector = load_face_detector()
37
+ global model
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ print(f"Using device: {device}") # Debug print
40
+ model = model.to(device)
41
 
42
  # Preprocess the image (convert PIL to numpy, ensure RGB)
43
  processed_image = preprocess_image(image)
app.py CHANGED
@@ -1,24 +1,13 @@
1
  # Standard library imports
2
  import os
3
  import gradio as gr
4
- import spaces
5
-
6
-
7
 
8
  # Local imports
9
- from age_estimation.age_estimation import age_estimation as _age_estimation
10
- from detection.face_detection import face_detection
11
- from detection.object_detection import object_detection as _object_detection
12
  from utils.ui_utils import update_input_visibility
13
 
14
- @spaces.GPU
15
- def age_estimation(input_type, uploaded_image, image_url, base64_string):
16
- return _age_estimation(input_type, uploaded_image, image_url, base64_string)
17
-
18
- @spaces.GPU
19
- def object_detection(input_type, uploaded_image, image_url, base64_string):
20
- return _object_detection(input_type, uploaded_image, image_url, base64_string)
21
-
22
  with gr.Blocks() as demo:
23
  # Add a title to the interface
24
  gr.Markdown("# Computer Vision Tools")
 
1
  # Standard library imports
2
  import os
3
  import gradio as gr
 
 
 
4
 
5
  # Local imports
6
+ from detection.object_detection import object_detection
7
+ from age_estimation.age_estimation import age_estimation
8
+ from detection.face_detection import face_detection
9
  from utils.ui_utils import update_input_visibility
10
 
 
 
 
 
 
 
 
 
11
  with gr.Blocks() as demo:
12
  # Add a title to the interface
13
  gr.Markdown("# Computer Vision Tools")
detection/object_detection.py CHANGED
@@ -11,12 +11,25 @@ YOLO_MODEL = "yolo11n.pt"
11
 
12
  # Load the YOLO model globally to avoid reloading on each function call
13
  # Using a common pre-trained YOLO nano model ('yolov8n.pt')
14
- try:
15
- model = YOLO(YOLO_MODEL)
16
- print("YOLO model loaded successfully.")
17
- except Exception as e:
18
- print(f"Error loading YOLO model: {e}")
19
- model = None # Set model to None if loading fails
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  def object_detection(input_type, uploaded_image, image_url, base64_string):
@@ -34,6 +47,8 @@ def object_detection(input_type, uploaded_image, image_url, base64_string):
34
  - numpy.ndarray: The image with detected objects drawn on it, or None if an error occurred or model not loaded.
35
  - dict: A dictionary containing the raw detection data (bounding boxes, classes, scores), or None.
36
  """
 
 
37
  if model is None:
38
  print("YOLO model is not loaded. Cannot perform object detection.")
39
  return None, None # Return None for both outputs
 
11
 
12
  # Load the YOLO model globally to avoid reloading on each function call
13
  # Using a common pre-trained YOLO nano model ('yolov8n.pt')
14
+ model = None
15
+
16
+
17
+ def load_model():
18
+ """
19
+ Load the YOLO model (YOLOv11 nano) for object detection.
20
+ This function is called once at the start to avoid reloading the model multiple times.
21
+ """
22
+ try:
23
+ global model
24
+ if model is not None:
25
+ print("YOLO model already loaded.")
26
+ return
27
+ model = YOLO(YOLO_MODEL)
28
+ print("YOLO model loaded successfully.")
29
+ return
30
+ except Exception as e:
31
+ print(f"Error loading YOLO model: {e}")
32
+ return
33
 
34
 
35
  def object_detection(input_type, uploaded_image, image_url, base64_string):
 
47
  - numpy.ndarray: The image with detected objects drawn on it, or None if an error occurred or model not loaded.
48
  - dict: A dictionary containing the raw detection data (bounding boxes, classes, scores), or None.
49
  """
50
+ load_model() # Load the model if not already loaded
51
+
52
  if model is None:
53
  print("YOLO model is not loaded. Cannot perform object detection.")
54
  return None, None # Return None for both outputs