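"""Interactive Streamlit app for exploring MarianMT translation models
token by token: pick a model, translate a sentence, then type a partial
translation and inspect the probabilities the model assigns to each
possible next token."""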
import streamlit as st

# Older Streamlit releases only provide st.experimental_singleton; alias it
# so the @st.cache_resource decorators below work on either version.
if not hasattr(st, "cache_resource"):
    st.cache_resource = st.experimental_singleton

import pandas as pd
import torch
import torch.nn.functional as F
from transformers import MarianMTModel, MarianTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_options = [
    'Helsinki-NLP/opus-mt-roa-en',
    'Helsinki-NLP/opus-mt-en-roa',
]
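# "roa" is the OPUS group code for Romance languages, so these checkpoints
# translate between Romance languages and English. Multilingual targets
# expect a ">>xxx<<" language token prepended to the source text.
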
col1, col2 = st.columns(2)
with col1:
    model_name = st.selectbox("Select a model", model_options + ['other'])
    if model_name == 'other':
        model_name = st.text_input("Enter model name", model_options[0])

@st.cache_resource
def get_tokenizer(model_name):
    return MarianTokenizer.from_pretrained(model_name)

@st.cache_resource
def get_model(model_name):
    model = MarianMTModel.from_pretrained(model_name).to(device)
    print(f"Loaded model, {model.num_parameters():,d} parameters.")
    return model
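# Both loaders are memoized on model_name, so switching models in the UI
# downloads each checkpoint once and reuses it across Streamlit reruns.
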
tokenizer = get_tokenizer(model_name)
model = get_model(model_name)

# Multilingual Marian tokenizers expose target-language tokens (e.g. ">>spa<<");
# single-pair models return an empty list.
if tokenizer.supported_language_codes:
    lang_code = st.selectbox("Select a language", tokenizer.supported_language_codes)
else:
    lang_code = None

with col2:
    input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")

input_text = input_text.strip()
if not input_text:
    st.stop()

# Multilingual targets need the language token prepended to the source text.
if lang_code:
    input_text = f"{lang_code} {input_text}"

input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

example_generations = model.generate(
    input_ids,
    num_beams=4,
    num_return_sequences=4,
    max_length=100,
)
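# With num_return_sequences equal to num_beams, generate() returns every
# beam-search hypothesis, ordered from highest to lowest beam score.
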
col1, col2 = st.columns(2)
with col1:
    st.write("Example generations:")
    st.write('\n'.join(
        '- ' + translation
        for translation in tokenizer.batch_decode(example_generations, skip_special_tokens=True)))

with col2:
    # Position 0 of a generated sequence is the decoder start token, so
    # index 1 holds the first real output token.
    example_first_word = tokenizer.decode(example_generations[0, 1])
    output_so_far = st.text_input("Enter text translated so far", example_first_word)

# Tokenize the prefix with the target-language vocabulary.
with tokenizer.as_target_tokenizer():
    output_tokens = tokenizer.tokenize(output_so_far)
    decoder_input_ids = tokenizer.convert_tokens_to_ids(output_tokens)
# Prepend the start token the decoder expects.
decoder_input_ids = [model.config.decoder_start_token_id] + decoder_input_ids
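# Illustrative example (the actual pieces depend on the model's SentencePiece
# vocabulary): for the prefix "Hello," this might yield tokens like
# ["▁Hello", ","] and ids [start_id, 1234, 5678].
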
# Run one teacher-forced forward pass: the encoder sees the source text, the
# decoder is fed the prefix, and logits[0, -1] scores the next token.
with torch.no_grad():
    model_output = model(
        input_ids=input_ids,
        decoder_input_ids=torch.tensor([decoder_input_ids]).to(device))

with st.expander("Configuration"):
    top_k = st.slider("Number of tokens to show", min_value=1, max_value=100, value=5)
    # A minimum of 0.01 avoids dividing by zero when applying temperature.
    temperature = st.slider("Temperature", min_value=0.01, max_value=2.0, value=1.0, step=0.01)
    show_token_ids = st.checkbox("Show token IDs", value=False)
    show_logprobs = st.checkbox("Show log probabilities", value=False)
    show_cumulative_probs = st.checkbox("Show cumulative probabilities", value=False)

last_token_logits = model_output.logits[0, -1].cpu()
assert len(last_token_logits.shape) == 1

# Temperature rescales logits before the softmax:
# p_i = exp(z_i / T) / sum_j exp(z_j / T).
last_token_logits_with_temperature = last_token_logits / temperature
# Top-k is taken on the raw logits; the ranking is unchanged by any
# positive temperature, only the probabilities are.
most_likely_tokens = last_token_logits.topk(k=top_k)
probs = last_token_logits_with_temperature.softmax(dim=-1)
probs_for_likely_tokens = probs[most_likely_tokens.indices]

with tokenizer.as_target_tokenizer():
    prob_table_data = {
        'token': [tokenizer.decode(token_id) for token_id in most_likely_tokens.indices],
    }
if show_token_ids:
    prob_table_data['id'] = most_likely_tokens.indices
prob_table_data['probability'] = probs_for_likely_tokens
if show_logprobs:
    # Log-probabilities are reported from the raw logits (temperature 1).
    prob_table_data['logprob'] = last_token_logits.log_softmax(dim=-1)[most_likely_tokens.indices]
if show_cumulative_probs:
    prob_table_data['cumulative probability'] = probs_for_likely_tokens.cumsum(0)
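# The running sum shows how much probability mass the model concentrates in
# its top candidates, the same quantity that nucleus (top-p) sampling thresholds.
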
probs_table = pd.DataFrame(prob_table_data)
st.subheader("Most likely next tokens")
st.table(probs_table.style.hide(axis='index'))

if len(decoder_input_ids) > 1:
    st.subheader("Loss by already-generated token")
    # logits[0, t] scores position t+1, so positions [:-1] are paired with
    # target ids [1:] when computing the per-token cross-entropy.
    loss_table = pd.DataFrame({
        'token': [tokenizer.decode(token_id) for token_id in decoder_input_ids[1:]],
        'loss': F.cross_entropy(
            model_output.logits[0, :-1],
            torch.tensor(decoder_input_ids[1:]).to(device),
            reduction='none').cpu(),
    })
    st.write(loss_table)
    st.write("Total loss so far:", loss_table.loss.sum())