kcarnold commited on
Commit
fa263f0
·
1 Parent(s): 3ff8b95

Avoid having to import all of torch if we're using the API.

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -1,9 +1,4 @@
1
  import streamlit as st
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
  import pandas as pd
8
  import html
9
 
@@ -20,10 +15,12 @@ if model_name == 'other':
20
 
21
  @st.cache_resource
22
  def get_tokenizer(model_name):
 
23
  return AutoTokenizer.from_pretrained(model_name).from_pretrained(model_name)
24
 
25
  @st.cache_resource
26
  def get_model(model_name):
 
27
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.bfloat16)
28
  print(f"Loaded model, {model.num_parameters():,d} parameters.")
29
  return model
@@ -33,6 +30,8 @@ doc = st.text_area("Document", "This is a document that I would like to have rew
33
 
34
 
35
  def get_spans_local(prompt, doc):
 
 
36
  tokenizer = get_tokenizer(model_name)
37
  model = get_model(model_name)
38
 
 
1
  import streamlit as st
 
 
 
 
 
2
  import pandas as pd
3
  import html
4
 
 
15
 
16
  @st.cache_resource
17
  def get_tokenizer(model_name):
18
+ from transformers import AutoTokenizer
19
  return AutoTokenizer.from_pretrained(model_name).from_pretrained(model_name)
20
 
21
  @st.cache_resource
22
  def get_model(model_name):
23
+ from transformers import AutoModelForCausalLM
24
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.bfloat16)
25
  print(f"Loaded model, {model.num_parameters():,d} parameters.")
26
  return model
 
30
 
31
 
32
  def get_spans_local(prompt, doc):
33
+ import torch
34
+
35
  tokenizer = get_tokenizer(model_name)
36
  model = get_model(model_name)
37