"""Gradio demo for English <-> Norwegian (Bokmål / Nynorsk) translation.

Loads the ``nort5_en-no_base`` seq2seq model, streams the translation of each
input line to the browser while it is being generated, and appends an
anonymous timestamp to a usage-statistics dataset after every translation.
"""

import json
import os
import shutil
from datetime import datetime
from threading import Thread

import gradio as gr
import huggingface_hub
import torch
from huggingface_hub import Repository
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
from transformers.generation import LogitsProcessor


MODEL_PATH = "nort5_en-no_base"
MAX_SOURCE_TOKENS = 512

print("Starting to load the model to memory")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
pad_index = tokenizer.convert_tokens_to_ids("[PAD]")
eng_index = tokenizer.convert_tokens_to_ids(">>eng<<")
nob_index = tokenizer.convert_tokens_to_ids(">>nob<<")
nno_index = tokenizer.convert_tokens_to_ids(">>nno<<")

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, trust_remote_code=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)

model = model.to(device)
model.eval()

print("Sucessfully loaded the model to the memory")


# Maps the label shown in the dropdowns to the language-tag token id.
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index,
}
# Dropdown choices, in the insertion order of LANGUAGE_IDS.
LANGUAGES = list(LANGUAGE_IDS)

STATS_REPO = "https://huggingface.co/datasets/ltg/usage_statistics"
STATS_DIR = "data"
USAGE_LOG_PATH = "data/no-en-translation.jsonl"
HF_TOKEN = os.environ.get("HF_TOKEN")


def _clone_stats_repo():
    """Clone the usage-statistics dataset into ``STATS_DIR`` and return it."""
    return Repository(
        local_dir=STATS_DIR,
        clone_from=STATS_REPO,
        use_auth_token=HF_TOKEN,
    )


dataset = _clone_stats_repo()


def _append_timestamp_and_push(path):
    """Append the current timestamp as one JSON line to *path* and push it."""
    with open(path, "a", encoding="utf-8") as f:
        line = json.dumps(str(datetime.now()), ensure_ascii=False)
        f.write(f"{line}\n")
    dataset.push_to_hub(blocking=False)


def add_anonymous_usage_log(path):
    """Log the timestamp of a query to the usage-statistics dataset.

    The local clone is pulled first, then the timestamp is appended to *path*
    and pushed without blocking. If any of that fails, the local clone is
    deleted, cloned again and the write is retried once. Logging is
    best-effort: a failure of the retry is printed instead of being raised,
    so that it cannot turn a finished translation into an error.

    Args:
        path: Path of the JSON-lines log file inside the local clone.
    """
    global dataset

    try:
        dataset.git_pull()
        _append_timestamp_and_push(path)
        return
    except Exception as error:
        print(
            f"SYSTEM: usage logging failed ({error!r}), re-cloning the statistics repository",
            flush=True,
        )

    try:
        # The clone may be missing or broken; start again from a fresh copy.
        shutil.rmtree(STATS_DIR, ignore_errors=True)
        dataset = _clone_stats_repo()
        _append_timestamp_and_push(path)
    except Exception as error:
        print(f"SYSTEM: usage logging failed again ({error!r}), giving up", flush=True)


class BatchStreamer(TextIteratorStreamer):
    """Text streamer that supports a batch of sequences.

    One token cache is kept per batch row. Every update emits the *whole*
    text decoded so far (not an increment), with the rows joined by newlines.
    """

    def _decode_cache(self):
        """Decode every cached row and join the stripped results by newlines."""
        paragraphs = [
            self.tokenizer.decode(tokens, **self.decode_kwargs).strip()
            for tokens in self.token_cache
        ]
        return "\n".join(paragraphs)

    def put(self, value):
        """Add newly generated tokens to the caches and emit the decoded text.

        Args:
            value: Tensor whose first dimension indexes the batch rows; after
                ``tolist()`` each row is either a single token id or a list
                of token ids.
        """
        if not self.token_cache:
            self.token_cache = [[] for _ in range(value.size(0))]

        for tokens, new in zip(self.token_cache, value.tolist()):
            tokens += [new] if isinstance(new, int) else new

        self.on_finalized_text(self._decode_cache())

    def end(self):
        """Emit the final text, reset the caches and signal the stream end."""
        if self.token_cache:
            printable_text = self._decode_cache()
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)


class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Lower the scores of tokens that are already present in ``input_ids``.

    The per-token offset is ``penalty * (log_softmax(bias) - max)`` computed
    from the bias of the last layer of ``model.classifier.nonlinearity``. It
    is zero for the token with the largest bias and negative for all others
    (for a positive ``penalty``).
    """

    def __init__(self, penalty: float, model):
        """Precompute the per-token offsets.

        Args:
            penalty: Multiplier applied to the normalized log-softmax bias.
            model: Model exposing ``classifier.nonlinearity[-1].bias``.
        """
        last_bias = model.classifier.nonlinearity[-1].bias.data
        # An explicit dim avoids the deprecated implicit-dimension behavior.
        # NOTE(review): assumes the bias is 1-D (one entry per token), which
        # makes dim=-1 identical to the previous implicit choice - confirm.
        last_bias = torch.nn.functional.log_softmax(last_bias, dim=-1)
        self.penalty = penalty * (last_bias - last_bias.max())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Overwrite, in place, the scores at the positions listed in ``input_ids``.

        Args:
            input_ids: Token ids, indexed along dimension 1 of ``scores``.
            scores: Scores to modify.

        Returns:
            The same ``scores`` tensor, modified in place.
        """
        offsets = self.penalty.unsqueeze(0).to(input_ids.device)
        penalized_score = torch.gather(scores + offsets, 1, input_ids)
        scores.scatter_(1, input_ids, penalized_score)
        return scores


# The processor only holds tensors derived from the fixed model weights, so
# one instance is built at start-up instead of once per request.
_LOGITS_PROCESSORS = [RepetitionPenaltyLogitsProcessor(1.0, model)]


def translate(source, source_language, target_language):
    """Translate *source* line by line, yielding the text generated so far.

    Every line of *source* is translated as a separate batch row. Each
    yielded string is the complete current translation (all rows joined by
    newlines), so a consumer should replace, not append to, its output.

    Args:
        source: Text to translate; lines are separated by ``"\\n"``.
        source_language: Key of ``LANGUAGE_IDS`` for the input language.
        target_language: Key of ``LANGUAGE_IDS`` for the output language.

    Yields:
        The stripped translation produced so far.
    """
    if source_language == target_language:
        # Nothing to translate; the generator return value is kept only for
        # callers that drive the generator by hand.
        yield source.strip()
        return source.strip()

    sentences = [line.strip() for line in source.split("\n")]

    # Each row is: [CLS] <target tag> <source tag> tokens... [SEP]
    prefix = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]]
    rows = [
        torch.tensor(prefix + subwords + [sep_index])
        for subwords in tokenizer(sentences).input_ids
    ]
    source_subwords = torch.nn.utils.rnn.pad_sequence(
        rows, batch_first=True, padding_value=pad_index
    )
    # NOTE(review): this cut also removes the trailing [SEP] of any row that
    # is longer than the limit - confirm whether the model tolerates that.
    source_subwords = source_subwords[:, :MAX_SOURCE_TOKENS].to(device)

    streamer = BatchStreamer(tokenizer, timeout=60.0, skip_special_tokens=True)

    def generate(model, **kwargs):
        with torch.inference_mode():
            with torch.autocast(enabled=device != "cpu", device_type=device, dtype=torch.bfloat16):
                return model.generate(**kwargs)

    generate_kwargs = dict(
        streamer=streamer,
        input_ids=source_subwords,
        attention_mask=(source_subwords != pad_index).long(),
        max_new_tokens=512 - 1,
        num_beams=1,
        logits_processor=_LOGITS_PROCESSORS,
        do_sample=False,
        use_cache=True,
    )

    # Generation runs in a worker thread so the streamer can be consumed here.
    t = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
    t.start()

    new_text = ""
    for new_text in streamer:
        yield new_text.strip()

    # The stream only ends after generation has finished, so this is quick.
    t.join()

    add_anonymous_usage_log(USAGE_LOG_PATH)

    return new_text.strip()


def switch_inputs(source, target, source_language, target_language):
    """Return the arguments with source and target (text and language) swapped."""
    return target, source, target_language, source_language


with gr.Blocks() as demo:
    gr.Markdown("# Norwegian-English translation")

    with gr.Row():
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[1], show_label=False
            )
            source = gr.Textbox(
                label="Source text",
                placeholder="What do you want to translate?",
                show_label=False,
                lines=7,
                max_lines=100,
                autofocus=True,
            )
            submit = gr.Button("Submit", variant="primary")

        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(
                LANGUAGES, value=LANGUAGES[0], show_label=False
            )
            target = gr.Textbox(
                label="Translation",
                show_label=False,
                interactive=False,
                lines=7,
                max_lines=100,
            )

    # Components that are locked while a translation is running.
    _controls = [source, submit, source_language, target_language]

    def _set_controls_interactive(interactive):
        """Build an update that sets ``interactive`` on every input control."""
        return {control: gr.update(interactive=interactive) for control in _controls}

    def update_state_after_user():
        """Disable the input controls while a translation is running."""
        return _set_controls_interactive(False)

    def update_state_after_return():
        """Re-enable the input controls once the translation has finished."""
        return _set_controls_interactive(True)

    def _wire_translation(trigger):
        """Attach the lock -> translate -> unlock event chain to *trigger*."""
        return trigger(
            fn=update_state_after_user,
            inputs=None,
            outputs=_controls,
            queue=False,
        ).then(
            fn=translate,
            inputs=[source, source_language, target_language],
            outputs=[target],
            queue=True,
        ).then(
            fn=update_state_after_return,
            inputs=None,
            outputs=_controls,
            queue=False,
        )

    # Pressing Enter in the text box and clicking the button do the same thing.
    submit_event = _wire_translation(source.submit)
    submit_click_event = _wire_translation(submit.click)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()