supersolar commited on
Commit
b22cd65
·
verified ·
1 Parent(s): c803a91

Update utils/florence.py

Browse files
Files changed (1) hide show
  1. utils/florence.py +6 -7
utils/florence.py CHANGED
@@ -27,14 +27,13 @@ def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
27
  def load_florence_model(
28
  device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
29
  ) -> Tuple[Any, Any]:
30
- with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
31
- model = AutoModelForCausalLM.from_pretrained(
32
- checkpoint, trust_remote_code=True).to(device).eval()
33
- processor = AutoProcessor.from_pretrained(
34
- checkpoint, trust_remote_code=True)
35
- return model, processor
36
-
37
 
 
38
  def run_florence_inference(
39
  model: Any,
40
  processor: Any,
 
27
  def load_florence_model(
28
  device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
29
  ) -> Tuple[Any, Any]:
30
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
31
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
32
+ model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
33
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
34
+ return model, processor
 
 
35
 
36
+
37
  def run_florence_inference(
38
  model: Any,
39
  processor: Any,