Sa-m commited on
Commit
edb3cf3
·
1 Parent(s): 929e161

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -1
app.py CHANGED
@@ -15,7 +15,28 @@ def detect(inp):
15
 
16
 
17
 
18
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  inp = gr.inputs.Image(type="filepath", label="Input")
 
15
 
16
 
17
 
18
+ def custom(path_or_model='path/to/model.pt', autoshape=True):
19
+ """custom mode
20
+ Arguments (3 options):
21
+ path_or_model (str): 'path/to/model.pt'
22
+ path_or_model (dict): torch.load('path/to/model.pt')
23
+ path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
24
+ Returns:
25
+ pytorch model
26
+ """
27
+ model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint
28
+ if isinstance(model, dict):
29
+ model = model['ema' if model.get('ema') else 'model'] # load model
30
+
31
+ hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
32
+ hub_model.load_state_dict(model.float().state_dict()) # load state_dict
33
+ hub_model.names = model.names # class names
34
+ if autoshape:
35
+ hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
36
+ device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
37
+ return hub_model.to(device)
38
+
39
+ model = custom(path_or_model='best.pt')
40
 
41
 
42
  inp = gr.inputs.Image(type="filepath", label="Input")