panda1835 committed on
Commit
0db6636
·
verified ·
1 Parent(s): 15eb587

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -2
app.py CHANGED
@@ -10,7 +10,7 @@ import torchvision.transforms as T
10
  from PIL import Image
11
  import gradio as gr
12
  from datetime import datetime
13
-
14
  import models
15
 
16
  print(f"Is CUDA available: {torch.cuda.is_available()}")
@@ -66,10 +66,45 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  linear_model_name = 'linear_model.pt'
67
  classify_model = models.LinearClassifier(input_dim=768, output_dim=num_classes)
68
  classify_model.load_state_dict(torch.load(linear_model_name))
69
-
70
  k = 5
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def classify(image):
 
 
 
 
 
 
 
 
73
  embedding = extract_embedding(image)
74
  embedding = embedding['embedding']
75
  output = classify_model(torch.Tensor(embedding).to(device))
 
10
  from PIL import Image
11
  import gradio as gr
12
  from datetime import datetime
13
+ from ultralytics import YOLO
14
  import models
15
 
16
  print(f"Is CUDA available: {torch.cuda.is_available()}")
 
66
  linear_model_name = 'linear_model.pt'
67
  classify_model = models.LinearClassifier(input_dim=768, output_dim=num_classes)
68
  classify_model.load_state_dict(torch.load(linear_model_name))
69
+ detect_model = YOLO('yolov8m_2023-10-23_best.pt')
70
  k = 5
71
 
72
def detect(image):
    """Run the detector on ``image`` and return the first box as a region.

    Calls the module-level ``detect_model`` with a confidence threshold of
    0.1, prints a timestamp for the request, and converts the first returned
    box from ``(x1, y1, x2, y2)`` corner form into top/left/width/height.

    Args:
        image: The input handed straight through to ``detect_model``.

    Returns:
        A dict with integer ``"top"``, ``"left"``, ``"width"`` and
        ``"height"`` keys. All four values are 0 when no box was returned or
        when the detector output could not be read.
    """
    results = detect_model(image, conf=0.1)

    # Log when the request was handled.
    print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    # Sentinel region that callers compare against to mean "nothing found".
    no_detection = {"top": 0, "left": 0, "width": 0, "height": 0}

    try:
        boxes = results[0].boxes.xyxy
        # An empty result is the normal "nothing found" outcome, not an error,
        # so handle it explicitly instead of relying on an IndexError.
        # len() is used because truthiness of a multi-element tensor is ambiguous.
        if len(boxes) == 0:
            return no_detection

        x1, y1, x2, y2 = boxes[0].cpu().numpy()[:4]
        return {
            "top": int(y1),
            "left": int(x1),
            # Subtract before truncating so rounding matches the corner values.
            "width": int(x2 - x1),
            "height": int(y2 - y1),
        }
    except Exception as exc:
        # Stay best-effort (callers expect a dict), but do not swallow
        # KeyboardInterrupt/SystemExit and do not hide the failure reason.
        print(f"Detection failed: {exc}")
        return no_detection
98
+
99
  def classify(image):
100
+ detection = detect(image)
101
+
102
+ if detection["top"] == 0 and detection["left"] == 0 and detection["width"] == 0 and detection["height"] == 0:
103
+ return {}
104
+ # Crop the image
105
+ image = image.crop((detection['left'], detection['top'], detection['left'] + detection['width'], detection['top'] + detection['height']))
106
+
107
+ # Perform the embedding search
108
  embedding = extract_embedding(image)
109
  embedding = embedding['embedding']
110
  output = classify_model(torch.Tensor(embedding).to(device))