kevinconka committed
Commit: e85cd1e
1 Parent(s): 79e2714

use lru cache for simplicity

Files changed (1):
  1. app.py +8 -21
app.py CHANGED
@@ -7,6 +7,7 @@ Any new model should implement the following functions:
 
 import os
 import glob
+import functools
 import spaces
 import gradio as gr
 from huggingface_hub import get_token
@@ -17,7 +18,6 @@ from utils import (
     FlaggedCounter,
 )
 from flagging import HuggingFaceDatasetSaver
-import numpy as np
 
 import install_private_repos
 from seavision import load_model, AHOY
@@ -48,29 +48,16 @@ h1 {
 }
 """
 
-rng = np.random.default_rng(0xDEADBEEF)
-
-class SeaVisionModel:
-    def __init__(self, model_name="ahoy-RGB-b2"):
-        self.model_name = model_name
-        self.model = None
-
-    def get_model(self):
-        if self.model is None:
-            self.model = load_model(self.model_name)
-        return self.model
-
-    def run_inference(self, image):
-        model = self.get_model()
-        results = model(image)
-        return results.draw(image, diameter=4)
-
-vision_model = SeaVisionModel()
+@functools.lru_cache(maxsize=1)
+def get_model():
+    return load_model("ahoy-RGB-b2")
 
 @spaces.GPU
 def inference(image):
     """Run inference on image and return annotated image."""
-    return vision_model.run_inference(image)
+    model = get_model()
+    results = model(image)
+    return results.draw(image, diameter=4)
 
 # Flagging
 dataset_name = "SEA-AI/crowdsourced-sea-images"
@@ -106,7 +93,7 @@ with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
         examples=glob.glob("examples/*.jpg"),
         inputs=img_input,
         outputs=img_output,
-        fn=lambda image: inference(model, image),
+        fn=inference,
         cache_examples=True,
     )
 
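The hand-rolled SeaVisionModel wrapper existed only to load the checkpoint lazily and reuse it across calls; decorating a zero-argument get_model() with functools.lru_cache(maxsize=1) gives the same load-once behaviour. Below is a minimal sketch of the pattern using only the standard library; the load_model stub stands in for seavision.load_model, which comes from the private repo installed via install_private_repos and is not available outside the Space.

import functools


def load_model(name: str):
    # Stand-in for an expensive checkpoint load (download + weight init).
    print(f"loading {name} ...")
    return object()


@functools.lru_cache(maxsize=1)
def get_model():
    # First call runs load_model; every later call returns the cached
    # instance, which is what the removed SeaVisionModel class did by hand.
    return load_model("ahoy-RGB-b2")


first = get_model()   # prints "loading ahoy-RGB-b2 ..."
second = get_model()  # served from the cache, no second load
assert first is second

The gr.Examples callback simplifies accordingly: fn=inference is enough, since inference takes only the image and resolves the cached model itself through get_model().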