bryanmildort commited on
Commit
221291b
·
1 Parent(s): 466be16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -19,10 +19,10 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
19
  # device_map = infer_auto_device_map(model, dtype="float16")
20
  # st.write(device_map)
21
 
22
- @st.cache(allow_output_mutation=True)
23
  def load_model():
24
- model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt_neo_notes", low_cpu_mem_usage=True)
25
- model.to(device)
26
  tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt_neo_notes")
27
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
28
  return pipe
 
19
  # device_map = infer_auto_device_map(model, dtype="float16")
20
  # st.write(device_map)
21
 
22
+ @st.cache
23
  def load_model():
24
+ model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt_neo_notes", low_cpu_mem_usage=True, load_in_8bit=True)
25
+ # model.to(device)
26
  tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt_neo_notes")
27
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
28
  return pipe