File size: 4,387 Bytes
65b683f
 
325ca0f
 
 
65b683f
 
 
 
 
 
 
 
 
9e882df
325ca0f
9e882df
 
325ca0f
9e882df
325ca0f
9e882df
325ca0f
 
9e882df
325ca0f
 
65b683f
abc9e3b
65b683f
 
 
abc9e3b
65b683f
5c41bd3
 
 
65b683f
 
 
206be88
abc9e3b
 
 
 
9e882df
65b683f
325ca0f
 
 
65b683f
 
 
 
abc9e3b
 
 
 
65b683f
 
 
abc9e3b
 
 
 
6d1408f
abc9e3b
325ca0f
 
 
 
 
 
 
 
 
 
 
 
abc9e3b
65b683f
 
 
 
 
 
 
 
 
 
 
 
 
6d1408f
 
 
 
 
 
 
65b683f
 
6d1408f
 
 
65b683f
6d1408f
65b683f
 
 
6d1408f
65b683f
6d1408f
 
 
 
 
 
 
 
 
65b683f
47f3f6e
325ca0f
abc9e3b
47f3f6e
325ca0f
47f3f6e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import streamlit as st

# Older Streamlit releases only expose this decorator under its experimental name;
# alias it so the `@st.cache_resource` decorators below work on either version.
try:
    st.cache_resource
except AttributeError:
    st.cache_resource = st.experimental_singleton

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import pandas as pd

# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

from transformers import MarianMTModel, MarianTokenizer

# Checkpoints offered in the picker; choosing 'other' lets the user type any model name.
model_options = [
    'Helsinki-NLP/opus-mt-roa-en',
    'Helsinki-NLP/opus-mt-en-roa',
]

col1, col2 = st.columns(2)

with col1:
    model_name = st.selectbox("Select a model", [*model_options, 'other'])
    # Free-form entry, pre-filled with the first built-in option.
    if model_name == 'other':
        model_name = st.text_input("Enter model name", model_options[0])

@st.cache_resource
def get_tokenizer(model_name):
    """Return the ``MarianTokenizer`` for the model name or path *model_name*.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    tokenizer instead of loading it again for the same name.
    """
    tok = MarianTokenizer.from_pretrained(model_name)
    return tok

@st.cache_resource
def get_model(model_name):
    """Load the ``MarianMTModel`` *model_name*, move it to ``device`` and return it.

    Wrapped in ``st.cache_resource`` so the weights are loaded once per model
    name rather than on every Streamlit rerun. Prints the parameter count.
    """
    loaded = MarianMTModel.from_pretrained(model_name)
    loaded = loaded.to(device)
    n_params = loaded.num_parameters()
    print(f"Loaded model, {n_params:,d} parameters.")
    return loaded

# Resolve the selected checkpoint (both loaders are cached).
tokenizer = get_tokenizer(model_name)
model = get_model(model_name)

# Only offer a language picker when the tokenizer reports language codes; the
# chosen code is prepended to the source text further down the script.
lang_code = None
if tokenizer.supported_language_codes:
    lang_code = st.selectbox("Select a language", tokenizer.supported_language_codes)


with col2:
    input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan").strip()

# Halt this script run until the user has entered something non-blank.
if not input_text:
    st.stop()

# prepend the language code if necessary
if lang_code:
    input_text = f"{lang_code} {input_text}"

encoded_source = tokenizer(input_text, return_tensors="pt")
input_ids = encoded_source.input_ids.to(device)

# Beam search for a few reference translations to show alongside the explorer.
example_generations = model.generate(
    input_ids,
    max_length=100,
    num_beams=4,
    num_return_sequences=4,
)

# Left column: the beam-search translations as a markdown bullet list.
# Right column: an editable decoder prefix, seeded with the first generated token.
col1, col2 = st.columns(2)
with col1:
    st.write("Example generations:")
    st.write('\n'.join(
        '- ' + translation
        for translation in tokenizer.batch_decode(example_generations, skip_special_tokens=True)))

with col2:
    # Index 0 of each generated sequence is presumably the decoder start token (the
    # same token prepended to the prefix before the forward pass), so index 1 is the
    # first generated token. Decoding a 1-element slice (rather than a scalar) makes
    # skip_special_tokens apply, so an immediate end-of-sequence seeds the box with ""
    # instead of a literal special-token string; it also yields "" rather than an
    # IndexError if the output had no token after the start token.
    example_first_word = tokenizer.decode(example_generations[0, 1:2], skip_special_tokens=True)
    output_so_far = st.text_input("Enter text translated so far", example_first_word)


# Tokenize the user's partial translation with the target-side tokenizer and
# prime the decoder with its start token ahead of those ids.
with tokenizer.as_target_tokenizer():
    output_tokens = tokenizer.tokenize(output_so_far)
    decoder_input_ids = [model.config.decoder_start_token_id]
    decoder_input_ids += tokenizer.convert_tokens_to_ids(output_tokens)

# One gradient-free forward pass over the source text and the decoder prefix.
with torch.no_grad():
    model_output = model(input_ids=input_ids,
                         decoder_input_ids=torch.tensor([decoder_input_ids]).to(device))

# Display controls for the next-token table.
with st.expander("Configuration"):
    top_k = st.slider("Number of tokens to show", min_value=1, max_value=100, value=5)
    temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=1.0, step=0.01)
    show_token_ids = st.checkbox("Show token IDs", value=False)
    show_logprobs = st.checkbox("Show log probabilities", value=False)
    show_cumulative_probs = st.checkbox("Show cumulative probabilities", value=False)

# Scores for the position after the last decoder-prefix token (batch item 0).
last_token_logits = model_output.logits[0, -1].cpu()
assert len(last_token_logits.shape) == 1
# Dividing by a positive temperature does not change the ranking, so the top-k
# tokens are taken from the raw logits.
most_likely_tokens = last_token_logits.topk(k=top_k)

# apply temperature
if temperature > 0:
    last_token_logits_with_temperature = last_token_logits / temperature
    probs = last_token_logits_with_temperature.softmax(dim=-1)
else:
    # The slider allows 0.0, where logits / temperature produces inf/nan entries and
    # the softmax comes out NaN. Use the temperature -> 0 limit instead: all
    # probability mass on the top-ranked token (the first row of the table).
    probs = torch.zeros_like(last_token_logits)
    probs[most_likely_tokens.indices[0]] = 1.0
probs_for_likely_tokens = probs[most_likely_tokens.indices]

top_token_ids = most_likely_tokens.indices.tolist()
with tokenizer.as_target_tokenizer():
    prob_table_data = {
        'token': [tokenizer.decode(token_id) for token_id in top_token_ids],
    }
    if show_token_ids:
        prob_table_data['id'] = top_token_ids
    prob_table_data['probability'] = probs_for_likely_tokens
    if show_logprobs:
        # Log-probabilities of the raw (untempered) model distribution -- the same
        # logits the per-token loss is computed from -- so they ignore the slider.
        prob_table_data['logprob'] = last_token_logits.log_softmax(dim=-1)[most_likely_tokens.indices]
    if show_cumulative_probs:
        prob_table_data['cumulative probability'] = probs_for_likely_tokens.cumsum(0)
    probs_table = pd.DataFrame(prob_table_data)

st.subheader("Most likely next tokens")
st.table(probs_table.style.hide(axis='index'))

# Only meaningful once the user's prefix contributes at least one token beyond
# the decoder start token.
if len(decoder_input_ids) > 1:
    st.subheader("Loss by already-generated token")
    prefix_ids = decoder_input_ids[1:]
    # Logits at position i are scored against the token at position i + 1; the last
    # position (the next, not-yet-written token) has no target and is dropped.
    per_token_loss = F.cross_entropy(
        model_output.logits[0, :-1],
        torch.tensor(prefix_ids).to(device),
        reduction='none',
    ).cpu()
    loss_table = pd.DataFrame({
        'token': [tokenizer.decode(token_id) for token_id in prefix_ids],
        'loss': per_token_loss,
    })
    st.write(loss_table)
    st.write("Total loss so far:", loss_table.loss.sum())