kcarnold committed
Commit abc9e3b
Parent: 9e882df

Make it fun

Files changed (1):
  app.py: +27 -8
app.py CHANGED
@@ -6,8 +6,6 @@ import torch.nn.functional as F
 import transformers
 import pandas as pd
 
-st.title("Streamlit + Transformers")
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 from transformers import MarianMTModel, MarianTokenizer
@@ -21,13 +19,15 @@ model_name = st.radio("Select a model", [
 if model_name == 'other':
     model_name = st.text_input("Enter model name", 'Helsinki-NLP/opus-mt-ROMANCE-en')
 
+if not hasattr(st, "cache_resource"):
+    st.cache_resource = st.experimental_singleton
 
 
-@st.experimental_singleton
+@st.cache_resource
 def get_tokenizer(model_name):
     return MarianTokenizer.from_pretrained(model_name)
 
-@st.experimental_singleton
+@st.cache_resource
 def get_model(model_name):
     model = MarianMTModel.from_pretrained(model_name).to(device)
     print(f"Loaded model, {model.num_parameters():,d} parameters.")
@@ -36,8 +36,10 @@ def get_model(model_name):
 tokenizer = get_tokenizer(model_name)
 model = get_model(model_name)
 
-if tokenizer.supported_language_codes is not None:
-    st.write(f"Supported languages: {tokenizer.supported_language_codes}")
+if tokenizer.supported_language_codes:
+    lang_code = st.selectbox("Select a language", tokenizer.supported_language_codes)
+else:
+    lang_code = None
 
 
 input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")
@@ -45,10 +47,22 @@ input_text = input_text.strip()
 if not input_text:
     st.stop()
 
+# prepend the language code if necessary
+if lang_code:
+    input_text = f"{lang_code} {input_text}"
+
 output_so_far = st.text_input("Enter text translated so far", "Hello, my")
 
 input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
 
+example_generations = model.generate(
+    input_ids,
+    num_beams=4,
+    num_return_sequences=4,
+)
+st.write("Example generations:")
+st.write(tokenizer.batch_decode(example_generations, skip_special_tokens=True))
+
 # tokenize the output so far
 with tokenizer.as_target_tokenizer():
     output_tokens = tokenizer.tokenize(output_so_far)
@@ -62,7 +76,6 @@ with torch.no_grad():
         input_ids = input_ids,
         decoder_input_ids = torch.tensor([decoder_input_ids]).to(device))
 
-
 last_token_logits = model_output.logits[0, -1].cpu()
 assert len(last_token_logits.shape) == 1
 most_likely_tokens = last_token_logits.topk(k=20)
@@ -79,5 +92,11 @@ with tokenizer.as_target_tokenizer():
     'cumulative probability': probs_for_likely_tokens.cumsum(0)
 })
 
-
 st.write(probs_table)
+
+loss_table = pd.DataFrame({
+    'token': [tokenizer.decode(token_id) for token_id in decoder_input_ids[1:]],
+    'loss': F.cross_entropy(model_output.logits[0, :-1], torch.tensor(decoder_input_ids[1:]).to(device), reduction='none').cpu()
+})
+st.write(loss_table)
+st.write("Total loss so far:", loss_table.loss.sum())