xingqiang commited on
Commit
4e8a01a
·
1 Parent(s): 90baed8

update models

Browse files
Files changed (1) hide show
  1. model.py +20 -11
model.py CHANGED
@@ -1,22 +1,31 @@
1
- from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
 
2
  import torch
3
- from config import MODEL_NAME
4
 
5
 
6
  class RadarDetectionModel:
7
  def __init__(self):
8
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(
9
- MODEL_NAME)
10
- self.model = AutoModelForObjectDetection.from_pretrained(MODEL_NAME)
 
11
  self.model.eval()
12
 
 
 
 
 
 
 
 
13
  @torch.no_grad()
14
  def detect(self, image):
15
- inputs = self.feature_extractor(images=image, return_tensors="pt")
16
- outputs = self.model(**inputs)
17
 
18
- target_sizes = torch.tensor([image.size[::-1]])
19
- results = self.feature_extractor.post_process_object_detection(
20
- outputs, threshold=0.5, target_sizes=target_sizes)[0]
 
21
 
22
- return results
 
1
+ from transformers import AutoConfig, AutoModelForObjectDetection
2
+ from PIL import Image
3
  import torch
 
4
 
5
 
6
  class RadarDetectionModel:
7
  def __init__(self):
8
+ self.config = AutoConfig.from_pretrained(
9
+ "Extremely4606/paligemma_9_19")
10
+ self.model = AutoModelForObjectDetection.from_pretrained(
11
+ "Extremely4606/paligemma_9_19")
12
  self.model.eval()
13
 
14
+ def preprocess_image(self, image):
15
+ # 这里需要根据模型的具体要求来处理图像
16
+ # 这只是一个示例,可能需要调整
17
+ image = image.resize((224, 224))
18
+ image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
19
+ return image.unsqueeze(0)
20
+
21
  @torch.no_grad()
22
  def detect(self, image):
23
+ inputs = self.preprocess_image(image)
24
+ outputs = self.model(inputs)
25
 
26
+ # 这里可能需要根据模型的输出格式进行调整
27
+ boxes = outputs.pred_boxes[0]
28
+ scores = outputs.scores[0]
29
+ labels = outputs.labels[0]
30
 
31
+ return {"boxes": boxes, "scores": scores, "labels": labels}