mikeee commited on
Commit
8bd8f70
·
1 Parent(s): 7b642fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -26,16 +26,24 @@ model = None
26
  gc.collect()
27
 
28
  logger.info("start")
 
29
 
30
- model = AutoModelForCausalLM.from_pretrained(
31
- "model", # loc
32
- device_map="auto",
33
- torch_dtype=torch.bfloat16,
34
- load_in_8bit=True,
35
- trust_remote_code=True,
36
- # use_ram_optimized_load=False,
37
- # offload_folder="offload_folder",
38
- )
 
 
 
 
 
 
 
39
 
40
  rich.print(f"{model=}")
41
 
 
26
  gc.collect()
27
 
28
  logger.info("start")
29
+ has_cuda = torch.cuda.is_available()
30
 
31
+ if has_cuda:
32
+ model = AutoModelForCausalLM.from_pretrained(
33
+ "model", # loc
34
+ # device_map="auto",
35
+ torch_dtype=torch.bfloat16,
36
+ load_in_8bit=True,
37
+ trust_remote_code=True,
38
+ # use_ram_optimized_load=False,
39
+ # offload_folder="offload_folder",
40
+ ).cuda()
41
+ else:
42
+ model = (
43
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).float()
44
+ ).float()
45
+
46
+ model = model.eval()
47
 
48
  rich.print(f"{model=}")
49