bryanmildort committed on
Commit
466be16
·
1 Parent(s): cf9b772

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -13,7 +13,7 @@ st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h
13
 
14
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
15
  # from accelerate import infer_auto_device_map
16
- # device = "cuda:0" if torch.cuda.is_available() else "cpu"
17
  # device_str = f"""Device being used: {device}"""
18
  # st.write(device_str)
19
  # device_map = infer_auto_device_map(model, dtype="float16")
@@ -21,10 +21,11 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
21
 
22
  @st.cache(allow_output_mutation=True)
23
  def load_model():
24
- model = pipeline("text-generation", model="bryanmildort/gpt_neo_notes")
25
- return model
26
-
27
- # model = model.to(device)
 
28
 
29
  pipe = load_model()
30
 
 
13
 
14
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
15
  # from accelerate import infer_auto_device_map
16
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
17
  # device_str = f"""Device being used: {device}"""
18
  # st.write(device_str)
19
  # device_map = infer_auto_device_map(model, dtype="float16")
 
21
 
22
  @st.cache(allow_output_mutation=True)
23
  def load_model():
24
+ model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt_neo_notes", low_cpu_mem_usage=True)
25
+ model.to(device)
26
+ tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt_neo_notes")
27
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
28
+ return pipe
29
 
30
  pipe = load_model()
31