kcarnold commited on
Commit
9e882df
·
1 Parent(s): 5c41bd3

Model selector

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -11,7 +11,17 @@ st.title("Streamlit + Transformers")
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  from transformers import MarianMTModel, MarianTokenizer
14
- model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
 
 
 
 
 
 
 
 
 
 
15
 
16
  @st.experimental_singleton
17
  def get_tokenizer(model_name):
@@ -26,6 +36,9 @@ def get_model(model_name):
26
  tokenizer = get_tokenizer(model_name)
27
  model = get_model(model_name)
28
 
 
 
 
29
 
30
  input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")
31
  input_text = input_text.strip()
@@ -52,7 +65,7 @@ with torch.no_grad():
52
 
53
  last_token_logits = model_output.logits[0, -1].cpu()
54
  assert len(last_token_logits.shape) == 1
55
- most_likely_tokens = last_token_logits.topk(k=5)
56
 
57
  probs = last_token_logits.softmax(dim=-1)
58
  probs_for_likely_tokens = probs[most_likely_tokens.indices]
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  from transformers import MarianMTModel, MarianTokenizer
14
+
15
+ model_name = st.radio("Select a model", [
16
+ 'Helsinki-NLP/opus-mt-roa-en',
17
+ 'Helsinki-NLP/opus-mt-en-roa',
18
+ 'other'
19
+ ])
20
+
21
+ if model_name == 'other':
22
+ model_name = st.text_input("Enter model name", 'Helsinki-NLP/opus-mt-ROMANCE-en')
23
+
24
+
25
 
26
  @st.experimental_singleton
27
  def get_tokenizer(model_name):
 
36
  tokenizer = get_tokenizer(model_name)
37
  model = get_model(model_name)
38
 
39
+ if tokenizer.supported_language_codes is not None:
40
+ st.write(f"Supported languages: {tokenizer.supported_language_codes}")
41
+
42
 
43
  input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")
44
  input_text = input_text.strip()
 
65
 
66
  last_token_logits = model_output.logits[0, -1].cpu()
67
  assert len(last_token_logits.shape) == 1
68
+ most_likely_tokens = last_token_logits.topk(k=20)
69
 
70
  probs = last_token_logits.softmax(dim=-1)
71
  probs_for_likely_tokens = probs[most_likely_tokens.indices]