cyber-chris commited on
Commit
36a864b
·
1 Parent(s): eef0b87

update device_map

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -6,10 +6,11 @@ from repl import generate_with_dms
6
  import gradio as gr
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
9
 
10
  hf_model = AutoModelForCausalLM.from_pretrained(
11
  "meta-llama/Meta-Llama-3-8B-Instruct",
12
- device_map="auto",
13
  torch_dtype="float16",
14
  )
15
  model = HookedSAETransformer.from_pretrained_no_processing(
 
6
  import gradio as gr
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+ print(f"DEVICE: {DEVICE}")
10
 
11
  hf_model = AutoModelForCausalLM.from_pretrained(
12
  "meta-llama/Meta-Llama-3-8B-Instruct",
13
+ device_map="auto" if DEVICE == "cuda" else DEVICE,
14
  torch_dtype="float16",
15
  )
16
  model = HookedSAETransformer.from_pretrained_no_processing(