rawc0der committed on
Commit
96db79f
·
1 Parent(s): b9620be

change model context

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import spaces
 
2
  from fastapi import FastAPI, HTTPException, UploadFile, File
3
  from typing import Optional, Dict, Any
4
  import torch
@@ -123,8 +124,8 @@ class ModelManager:
123
  if hasattr(self.current_pipeline, 'enable_xformers_memory_efficient_attention'):
124
  self.current_pipeline.enable_xformers_memory_efficient_attention()
125
 
126
- if self._device == "cuda":
127
- self.current_pipeline.enable_model_cpu_offload()
128
 
129
  self.current_model = model_name
130
 
@@ -159,7 +160,8 @@ class ModelContext:
159
  pipeline = model_manager.load_model(self.model_name)
160
  if hasattr(pipeline, 'reset_device_map'):
161
  pipeline.reset_device_map()
162
- pipeline.enable_model_cpu_offload()
 
163
  return pipeline
164
 
165
  def __exit__(self, exc_type, exc_val, exc_tb):
 
1
  import spaces
2
+ from accelerate import dispatch_model
3
  from fastapi import FastAPI, HTTPException, UploadFile, File
4
  from typing import Optional, Dict, Any
5
  import torch
 
124
  if hasattr(self.current_pipeline, 'enable_xformers_memory_efficient_attention'):
125
  self.current_pipeline.enable_xformers_memory_efficient_attention()
126
 
127
+ # if self._device == "cuda":
128
+ # self.current_pipeline.enable_model_cpu_offload()
129
 
130
  self.current_model = model_name
131
 
 
160
  pipeline = model_manager.load_model(self.model_name)
161
  if hasattr(pipeline, 'reset_device_map'):
162
  pipeline.reset_device_map()
163
+ # Enable automatic device mapping strategy
164
+ dispatch_model(pipeline)
165
  return pipeline
166
 
167
  def __exit__(self, exc_type, exc_val, exc_tb):