smellslikeml commited on
Commit
1afbcbd
1 Parent(s): 14bb913

update app

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -33,6 +33,17 @@ INTRO_TEXT = """SpaceLlama3.1 demo\n\n
33
  # Set model location as a constant outside the function
34
  MODEL_LOCATION = "remyxai/SpaceLlama3.1" # Update as needed
35
 
 
 
 
 
 
 
 
 
 
 
 
36
  def compute(image, prompt):
37
  """Runs model inference."""
38
  if image is None:
@@ -44,10 +55,8 @@ def compute(image, prompt):
44
  if isinstance(image, str):
45
  image = PIL.Image.open(image).convert("RGB")
46
 
47
- # Set device and load the model
48
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
49
- vlm = load(MODEL_LOCATION) # Use the constant for model location
50
- vlm.to(device, dtype=torch.bfloat16)
51
 
52
  # Prepare prompt
53
  prompt_builder = vlm.get_prompt_builder()
@@ -116,5 +125,8 @@ if __name__ == "__main__":
116
  for k, v in os.environ.items():
117
  logging.info('environ["%s"] = %r', k, v)
118
 
 
 
 
119
  create_app().queue().launch()
120
 
 
33
  # Set model location as a constant outside the function
34
  MODEL_LOCATION = "remyxai/SpaceLlama3.1" # Update as needed
35
 
36
+ # Global model variable
37
+ global_model = None
38
+
39
+ def load_model():
40
+ """Loads the model globally."""
41
+ global global_model
42
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
43
+ global_model = load(MODEL_LOCATION)
44
+ global_model.to(device, dtype=torch.bfloat16)
45
+ logging.info("Model loaded successfully.")
46
+
47
  def compute(image, prompt):
48
  """Runs model inference."""
49
  if image is None:
 
55
  if isinstance(image, str):
56
  image = PIL.Image.open(image).convert("RGB")
57
 
58
+ # Use the globally loaded model
59
+ vlm = global_model
 
 
60
 
61
  # Prepare prompt
62
  prompt_builder = vlm.get_prompt_builder()
 
125
  for k, v in os.environ.items():
126
  logging.info('environ["%s"] = %r', k, v)
127
 
128
+ # Load the model once globally
129
+ load_model()
130
+
131
  create_app().queue().launch()
132