Spaces:
Runtime error
Runtime error
import random | |
import os | |
import streamlit as st | |
import torch | |
from transformers import pipeline, set_seed | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# --- Runtime configuration (each value can be overridden via an environment variable) ---
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")  # optional Hugging Face access token
DEVICE = os.getenv("DEVICE", "cpu")  # e.g. "cpu" or "cuda:0"
# Half precision is only used off-CPU; CPU inference stays in float32.
DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
MODEL_NAME = os.getenv("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "1024"))  # upper bound of the length slider

# --- Static UI text ---
HEADER_INFO = "\n".join((
    "# BERTIN-GPT-J-6B",
    "Spanish BERTIN GPT-J-6B Model.",
))
SIDEBAR_INFO = "# Configuration"
PROMPT_BOX = "Introduzca su texto..."
EXAMPLES = [
    "¿Cuál es la capital de Francia? Respuesta:",
]
def style():
    """Inject the app's custom CSS: the Roboto web font and result-box styling.

    The markup is emitted with ``unsafe_allow_html=True`` so the raw
    ``<link>`` and ``<style>`` tags reach the page.
    """
    # NOTE(review): `.rtl` is styled left-to-right exactly like `.ltr`;
    # presumably a leftover from a right-to-left template - confirm before removing.
    st.markdown("""
    <link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
    <style>
    .ltr,
    textarea {
        font-family: Roboto, sans-serif !important;
        text-align: left;
        direction: ltr !important;
    }
    .ltr-box {
        border-bottom: 1px solid #ddd;
        padding-bottom: 20px;
    }
    .rtl {
        text-align: left;
        direction: ltr !important;
    }
    span.result-text {
        padding: 3px 3px;
        line-height: 32px;
    }
    span.generated-text {
        background-color: rgb(118 200 147 / 13%);
    }
    </style>""", unsafe_allow_html=True)
class Normalizer:
    """Light clean-up of generated text.

    It drops repeated sentences and cuts a trailing unfinished sentence.
    Sentences are the pieces of text between "." characters.
    """

    def remove_repetitions(self, text):
        """Remove repeated sentences, keeping the first occurrence of each.

        Two pieces count as the same sentence when they are equal after
        stripping surrounding whitespace, so "Hola. Hola." becomes "Hola.".
        Empty or whitespace-only pieces are never treated as repeats. Those
        pieces come from ellipses ("...") or a trailing "."; dropping them
        would remove the final period and make ``trim_last_sentence`` cut off
        the last complete sentence.
        """
        seen = set()
        kept = []
        for sentence in text.split("."):
            key = sentence.strip()
            if key:
                if key in seen:
                    continue
                seen.add(key)
            kept.append(sentence)
        return ".".join(kept)

    def trim_last_sentence(self, text):
        """Drop everything after the last ".", i.e. an unfinished final sentence.

        Returns "" when *text* contains no "." at all.
        """
        return text[:text.rfind(".") + 1]

    def clean_txt(self, text):
        """Remove repeated sentences, then trim the unfinished last sentence."""
        return self.trim_last_sentence(self.remove_repetitions(text))
class TextGeneration:
    """Load a Hugging Face causal language model and generate text with it.

    The model name, device, dtype and auth token come from the module-level
    settings ``MODEL_NAME``, ``DEVICE``, ``DTYPE`` and ``HF_AUTH_TOKEN``.
    Call :meth:`load` once before :meth:`generate`.
    """

    # UI-only settings that may appear in ``generation_kwargs`` but are not
    # generation arguments. Forwarding them to the pipeline can make
    # ``model.generate`` reject the call as an unused keyword argument.
    _NON_GENERATION_KWARGS = frozenset({"do_clean"})

    def __init__(self):
        self.tokenizer = None  # set by load()
        self.model = None  # set by load()
        self.generator = None  # set by load()
        self.task = "text-generation"
        self.model_name_or_path = MODEL_NAME
        set_seed(42)

    def load(self):
        """Load tokenizer and model, move the model to ``DEVICE``, build the pipeline."""
        auth_token = HF_AUTH_TOKEN if HF_AUTH_TOKEN else None  # "" means no token
        on_cpu = DEVICE == "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path, use_auth_token=auth_token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path,
            use_auth_token=auth_token,
            # EOS doubles as the padding token id.
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=not on_cpu,
        ).to(device=DEVICE, non_blocking=True)
        self.model.eval()
        # pipeline() takes -1 for CPU or a CUDA ordinal. A bare "cuda" has no
        # ":N" suffix (index None) and maps to device 0 instead of crashing.
        torch_device = torch.device(DEVICE)
        device_number = -1 if torch_device.type == "cpu" else (torch_device.index or 0)
        self.generator = pipeline(
            self.task, model=self.model, tokenizer=self.tokenizer, device=device_number,
        )

    def generate(self, prompt, generation_kwargs):
        """Return the text generated for *prompt*.

        Args:
            prompt: input text.
            generation_kwargs: settings forwarded to the pipeline.
                ``max_length`` is the number of tokens to add on top of the
                prompt. UI-only keys such as ``do_clean`` are dropped. The
                dict itself is not modified.

        Returns:
            ``generated_text`` of the first returned sequence. Presumably this
            is the prompt followed by its continuation (the pipeline default).
        """
        kwargs = {
            key: value for key, value in generation_kwargs.items()
            if key not in self._NON_GENERATION_KWARGS
        }
        prompt_length = len(self.tokenizer(prompt)["input_ids"])
        total_length = prompt_length + kwargs["max_length"]
        position_limit = self._max_positions()
        if position_limit is not None:
            # Never request a sequence longer than the model supports.
            total_length = min(total_length, position_limit)
        kwargs["max_length"] = total_length
        return self.generator(prompt, **kwargs)[0]["generated_text"]

    def _max_positions(self):
        """Return the model's maximum sequence length, or None if the config has none.

        GPT-J configs name it ``n_positions``; other architectures commonly
        use ``max_position_embeddings``.
        """
        config = self.model.config
        for attribute in ("n_positions", "max_position_embeddings"):
            limit = getattr(config, attribute, None)
            if limit is not None:
                return limit
        return None
def _cache_model(func):
    """Cache *func*'s return value across Streamlit reruns.

    Uses ``st.cache_resource`` when the installed Streamlit provides it, and
    falls back to the legacy ``st.cache(allow_output_mutation=True)`` otherwise.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_model
def load_text_generator():
    """Create and load the shared :class:`TextGeneration` instance.

    Cached so the (multi-gigabyte) model is loaded once, not on every script
    rerun triggered by a widget change.
    """
    generator = TextGeneration()
    generator.load()
    return generator
def _sidebar_generation_kwargs():
    """Render the sidebar controls and return the selected settings.

    Returns:
        dict with ``max_length``, ``top_k``, ``top_p``, ``temperature``,
        ``do_sample`` and the app-level ``do_clean`` flag. ``do_clean`` is not
        a generation argument.
    """
    max_length = st.sidebar.slider(
        label='Longitud máxima',
        help="Número máximo aproximado de palabras a generar.",
        min_value=1,
        max_value=MAX_LENGTH,
        # Keep the default inside the slider range when MAX_LENGTH is set below 50.
        value=min(50, MAX_LENGTH),
        step=1
    )
    top_k = st.sidebar.slider(
        label='Top-k',
        help="Número de palabras con alta probabilidad a mantener para el filtrado `top-k`",
        min_value=40,
        max_value=80,
        value=50,
        step=1
    )
    top_p = st.sidebar.slider(
        label='Top-p',
        help="Solo las palabras más probables con probabilidades que sumen `top_p` o más se mantienen para la generación.",
        min_value=0.0,
        max_value=1.0,
        value=0.95,
        step=0.01
    )
    temperature = st.sidebar.slider(
        label='Temperatura',
        help="Valor utilizado para modular las probabilidades de las siguientes palabras generadas.",
        min_value=0.1,
        max_value=10.0,
        value=0.8,
        step=0.05
    )
    do_sample = st.sidebar.selectbox(
        label='¿Muestrear?',
        options=(True, False),
        help="Si no se muestrea se usará una decodificación voraz (_greedy_).",
    )
    do_clean = st.sidebar.selectbox(
        label='¿Limpiar texto?',
        options=(True, False),
        help="Si eliminar o no las frases repetidas y recortar la última frase sin terminar.",
    )
    return {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }


def _render_result(prompt_text, generated_text):
    """Show *prompt_text* followed by the highlighted *generated_text*.

    Both strings come from the user or the model and are embedded in raw HTML
    (``unsafe_allow_html=True``), so they are HTML-escaped. Newlines become
    ``<br>`` so multi-line output stays inside the styled paragraph.
    """
    import html  # only needed for escaping here

    def to_html(value):
        return html.escape(value).replace("\n", "<br>")

    st.markdown(
        f'<p class="ltr ltr-box">'
        f'<span class="result-text">{to_html(prompt_text)} </span>'
        f'<span class="result-text generated-text">{to_html(generated_text)}</span>'
        f'</p>',
        unsafe_allow_html=True
    )


def main():
    """Streamlit entry point.

    Configures the page, reads the sidebar settings and the prompt, and
    renders the generated text when the button is pressed.
    """
    st.set_page_config(
        page_title="BERTIN-GPT-J-6B",
        page_icon="🇪🇸",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    style()
    with st.spinner('Cargando el modelo. Por favor, espere...'):
        generator = load_text_generator()
    st.sidebar.markdown(SIDEBAR_INFO)
    generation_kwargs = _sidebar_generation_kwargs()
    st.markdown(HEADER_INFO)
    prompts = EXAMPLES + ["Personalizado"]
    prompt = st.selectbox('Ejemplos', prompts, index=len(prompts) - 1)
    prompt_box = PROMPT_BOX if prompt == "Personalizado" else prompt
    text = st.text_area("Texto", prompt_box)
    generation_kwargs_ph = st.empty()
    cleaner = Normalizer()
    if st.button("¡Generar!"):
        with st.spinner(text="Generando..."):
            generation_kwargs_ph.markdown(
                ", ".join(f"`{k}`: {v}" for k, v in generation_kwargs.items())
            )
            if text.strip():
                # ``do_clean`` is an app setting, not a generation argument.
                model_kwargs = {
                    k: v for k, v in generation_kwargs.items() if k != "do_clean"
                }
                generated_text = generator.generate(text, model_kwargs)
                if generation_kwargs["do_clean"]:
                    generated_text = cleaner.clean_txt(generated_text)
                # The returned text presumably starts with the prompt. Drop it
                # so only the continuation gets the highlight styling.
                stripped = generated_text.strip()
                if stripped.startswith(text):
                    generated_text = stripped[len(text):].strip()
                _render_result(text, generated_text)
if __name__ == "__main__":
    # Script entry point (presumably launched with `streamlit run`).
    main()