"""Gradio demo: generate danbooru-style tags with the dart-v2-sft ONNX model.

Model collection:
https://huggingface.co/collections/p1atdev/dart-v2-danbooru-tags-transformer-v2-66291115701b6fe773399b0a
"""

import torch
import gradio as gr
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_id = "p1atdev/dart-v2-sft"

# Loading happens at import time: downloads/caches the ONNX weights and tokenizer.
model = ORTModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Separate tokenizer with a prefix space, as required for building bad-words id
# lists; see NoBadWordsLogitsProcessor docs:
# https://huggingface.co/docs/transformers/v4.44.2/en/internal/generation_utils#transformers.NoBadWordsLogitsProcessor
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)


def get_tokens_as_list(word_list):
    """Convert a sequence of words into a list of token-id lists.

    Intended for the ``bad_words_ids`` argument of ``model.generate``.
    NOTE(review): currently unused by this script — kept for when word
    banning is re-enabled.
    """
    tokens_list = []
    for word in word_list:
        tokenized_word = tokenizer_with_prefix_space(
            [word], add_special_tokens=False
        ).input_ids[0]
        tokens_list.append(tokenized_word)
    return tokens_list


def generate_tags(general_tags: str) -> str:
    """Generate a comma-separated tag string from comma-separated seed tags.

    Builds the dart-v2 prompt format documented at
    https://huggingface.co/p1atdev/dart-v2-sft#prompt-format and samples
    up to 128 new tokens.
    """
    # Normalize user input: trim whitespace and drop empty segments.
    # (Filter on the *stripped* value so whitespace-only segments like
    # "a, , b" don't leave an empty tag in the prompt.)
    general_tags = ",".join(
        stripped for tag in general_tags.split(",") if (stripped := tag.strip())
    )

    # Rating / aspect-ratio / length control tokens are fixed for this demo.
    prompt = (
        "<|bos|>"
        "<|rating:general|><|aspect_ratio:tall|><|length:long|>"
        f"{general_tags}<|identity:none|><|input_end|>"
    )
    inputs = tokenizer(prompt, return_tensors="pt").input_ids

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            top_k=100,
            max_new_tokens=128,
            num_beams=1,
        )

    # Decode token-by-token: the dart tokenizer encodes one tag per token,
    # so batch_decode over the 1-D id sequence yields one string per tag.
    # (Presumably intentional — the prompt tokens are included too; the
    # special-token filter removes the control tokens.)
    return ", ".join(
        tag
        for tag in tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
        if tag.strip() != ""
    )


demo = gr.Interface(
    fn=generate_tags,
    inputs=gr.TextArea("1girl, black hair", lines=4),
    outputs=gr.Textbox(show_copy_button=True),
    clear_btn=None,
    analytics_enabled=False,
)

if __name__ == "__main__":
    demo.launch()