"""Streamlit demo app for text generation with BERTIN-GPT-J-6B (Spanish).

Configuration is read from environment variables:

* ``HF_AUTH_TOKEN`` - optional Hugging Face token used to download the model.
* ``DEVICE`` - torch device string, e.g. ``cpu`` (default) or ``cuda:0``.
* ``MODEL_NAME`` - model repository id (default ``bertin-project/bertin-gpt-j-6B``).
* ``MAX_LENGTH`` - upper bound of the "max length" slider (default 1024).
"""
import html
import os
import random

import streamlit as st
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM


HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu")  # e.g. "cpu" or "cuda:0"
# float16 halves weight memory on accelerators; CPU keeps full float32 precision.
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))

HEADER_INFO = """
# BERTIN-GPT-J-6B
Spanish BERTIN GPT-J-6B Model.
""".strip()
SIDEBAR_INFO = """
# Configuration
""".strip()
PROMPT_BOX = "Introduzca su texto..."
EXAMPLES = [
    "¿Cuál es la capital de Francia? Respuesta:",
]


def style():
    """Inject the page's custom CSS through a raw-HTML markdown element."""
    # NOTE(review): the stylesheet payload of this string appears to be missing
    # from this copy of the file; restore the original <style> block if available.
    st.markdown(""" """, unsafe_allow_html=True)


class Normalizer:
    """Light, sentence-level post-processing of generated text."""

    def remove_repetitions(self, text):
        """Remove repeated sentences, keeping the first occurrence of each.

        Sentences are the pieces between ``"."`` characters; comparison is exact
        (surrounding whitespace included) and the original order is preserved.
        """
        # dicts keep insertion order, so this is an O(n) ordered de-duplication
        # (the previous list-membership scan was O(n^2)).
        return ".".join(dict.fromkeys(text.split(".")))

    def trim_last_sentence(self, text):
        """Drop a trailing incomplete sentence, i.e. anything after the last ``"."``.

        Text that contains no period at all is returned unchanged instead of
        being discarded entirely.
        """
        last_period = text.rfind(".")
        if last_period == -1:
            # No sentence boundary to trim to; keep what we have.
            return text
        return text[:last_period + 1]

    def clean_txt(self, text):
        """Apply ``remove_repetitions`` and then ``trim_last_sentence``."""
        return self.trim_last_sentence(self.remove_repetitions(text))


class TextGeneration:
    """Wrapper around a Hugging Face causal-LM ``text-generation`` pipeline.

    Call :meth:`load` once before :meth:`generate`.
    """

    # Options collected by the UI that drive app-side post-processing rather
    # than decoding; they must not be forwarded to the pipeline / ``generate``.
    APP_ONLY_KWARGS = ("do_clean",)

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.generator = None
        self.task = "text-generation"
        self.model_name_or_path = MODEL_NAME
        # Seed the RNGs used for sampling so a fresh process gives repeatable output.
        set_seed(42)

    @staticmethod
    def _pipeline_device(device):
        """Map a torch device string to the integer index ``pipeline`` expects.

        ``"cpu"`` -> ``-1``, bare ``"cuda"`` -> ``0``, ``"cuda:N"`` -> ``N``.
        """
        if device == "cpu":
            return -1
        if device == "cuda":
            return 0
        return int(device.split(":")[-1])

    def load(self):
        """Load tokenizer and model, move the model to ``DEVICE`` and build the pipeline."""
        auth_token = HF_AUTH_TOKEN or None  # treat an empty env value as "no token"
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path,
            use_auth_token=auth_token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path,
            use_auth_token=auth_token,
            # Reuse EOS as the pad token so open-ended generation has one.
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=DEVICE != "cpu",
        ).to(device=DEVICE, non_blocking=True)
        self.model.eval()
        self.generator = pipeline(
            self.task,
            model=self.model,
            tokenizer=self.tokenizer,
            device=self._pipeline_device(DEVICE),
        )

    def generate(self, prompt, generation_kwargs):
        """Generate a continuation of ``prompt``.

        Args:
            prompt: input text.
            generation_kwargs: decoding options from the UI. ``max_length`` is the
                number of *new* tokens requested; it is converted to an absolute
                length (prompt + new), capped at the model's ``n_positions``.
                App-only keys (see ``APP_ONLY_KWARGS``) are ignored. The caller's
                dict is not modified.

        Returns:
            The ``generated_text`` of the first pipeline result.
        """
        kwargs = {
            key: value
            for key, value in generation_kwargs.items()
            if key not in self.APP_ONLY_KWARGS
        }
        prompt_length = len(self.tokenizer(prompt)["input_ids"])
        kwargs["max_length"] = min(
            prompt_length + kwargs["max_length"],
            self.model.config.n_positions,
        )
        return self.generator(prompt, **kwargs)[0]["generated_text"]


# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour of
# st.cache_resource; switch once the deployed Streamlit version supports it.
@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Build and load the generator once per process (cached by Streamlit)."""
    generator = TextGeneration()
    generator.load()
    return generator


def main():
    """Render the Streamlit UI and run generation when the button is pressed."""
    st.set_page_config(
        page_title="BERTIN-GPT-J-6B",
        page_icon="🇪🇸",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    style()
    with st.spinner('Cargando el modelo. Por favor, espere...'):
        generator = load_text_generator()

    st.sidebar.markdown(SIDEBAR_INFO)

    max_length = st.sidebar.slider(
        label='Longitud máxima',
        help="Número máximo (aproximado) de palabras a generar.",
        min_value=1,
        max_value=MAX_LENGTH,
        value=50,
        step=1
    )
    top_k = st.sidebar.slider(
        label='Top-k',
        help="Número de palabras con alta probabilidad a mantener para el filtrado `top-k`",
        min_value=40,
        max_value=80,
        value=50,
        step=1
    )
    top_p = st.sidebar.slider(
        label='Top-p',
        help="Solo las palabras más probables con probabilidades que sumen `top_p` o más se mantienen para la generación.",
        min_value=0.0,
        max_value=1.0,
        value=0.95,
        step=0.01
    )
    temperature = st.sidebar.slider(
        label='Temperatura',
        help="Valor utilizado para modular las probabilidades de las siguientes palabras generadas.",
        min_value=0.1,
        max_value=10.0,
        value=0.8,
        step=0.05
    )
    do_sample = st.sidebar.selectbox(
        label='¿Muestrear?',
        options=(True, False),
        help="Si no se muestrea se usará una decodificación voraz (_greedy_).",
    )
    do_clean = st.sidebar.selectbox(
        label='¿Limpiar texto?',
        options=(True, False),
        help="Si eliminar o no las palabras repetidas y recortar las últimas frases sin terminar.",
    )
    # Shown to the user as-is; TextGeneration.generate drops the app-only
    # "do_clean" key before calling the model.
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }

    st.markdown(HEADER_INFO)
    prompts = EXAMPLES + ["Personalizado"]
    prompt = st.selectbox('Ejemplos', prompts, index=len(prompts) - 1)

    if prompt == "Personalizado":
        prompt_box = PROMPT_BOX
    else:
        prompt_box = prompt

    text = st.text_area("Texto", prompt_box)
    generation_kwargs_ph = st.empty()
    cleaner = Normalizer()
    if st.button("¡Generar!"):
        with st.spinner(text="Generando..."):
            generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
            if text:
                generated_text = generator.generate(text, generation_kwargs)
                if do_clean:
                    generated_text = cleaner.clean_txt(generated_text)
                # The text-generation pipeline echoes the prompt by default
                # (return_full_text=True); drop it so the prompt is shown once.
                if generated_text.strip().startswith(text):
                    generated_text = generated_text.replace(text, "", 1).strip()
                # Both strings are untrusted (user input / model output) and are
                # rendered with unsafe_allow_html, so escape them first.
                # NOTE(review): the original HTML wrapper/CSS classes appear to
                # have been stripped from this copy; a minimal <p> is used here.
                st.markdown(
                    '<p>'
                    f'{html.escape(text)} '
                    f'{html.escape(generated_text)}'
                    '</p>',
                    unsafe_allow_html=True,
                )


if __name__ == '__main__':
    main()