"""Gradio demo for the PleIAs/Segmentext token-classification model.

Each word of the input text is tagged with a predicted segment type and shown
with coloured highlights.
"""

import gradio as gr
from transformers import DebertaV2Tokenizer, DebertaV2ForTokenClassification
import torch

model_name = "PleIAs/Segmentext"
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
model = DebertaV2ForTokenClassification.from_pretrained(model_name)

# Class index -> segment label.
# NOTE(review): presumably mirrors the checkpoint's config.id2label; confirm
# against the model if it is ever updated.
id2label = {
    0: "author",
    1: "bibliography",
    2: "caption",
    3: "contact",
    4: "date",
    5: "dialog",
    6: "footnote",
    7: "keywords",
    8: "math",
    9: "paratext",
    10: "separator",
    11: "table",
    12: "text",
    13: "title",
}

# Segment label -> highlight colour used by gr.HighlightedText.
color_map = {
    "author": "blue",
    "bibliography": "purple",
    "caption": "orange",
    "contact": "cyan",
    "date": "green",
    "dialog": "yellow",
    "footnote": "pink",
    "keywords": "lightblue",
    "math": "red",
    "paratext": "lightgreen",
    "separator": "gray",
    "table": "brown",
    "text": "lightgray",
    "title": "gold",
}

# SentencePiece prefixes a token with this character when a word starts there,
# i.e. it stands for the whitespace preceding the word.
WORD_START = "▁"


def _merge_subwords(tokens, predictions, special_mask):
    """Group SentencePiece sub-word tokens into words labelled by their first piece.

    Args:
        tokens: token strings, as returned by ``tokenizer.convert_ids_to_tokens``.
        predictions: predicted class index for each token (parallel to ``tokens``).
        special_mask: 1 for special tokens such as [CLS]/[SEP], 0 otherwise
            (parallel to ``tokens``).

    Returns:
        A list of ``(word, label)`` tuples. Every word except the first keeps
        its leading space, so concatenating the words restores inter-word spacing.
    """
    segments = []
    current_word = ""
    current_label = None
    for token, label_id, is_special in zip(tokens, predictions, special_mask):
        if is_special:
            # [CLS]/[SEP]/padding are not part of the user's text.
            continue
        if token.startswith(WORD_START) or current_label is None:
            # A new word begins: flush the previous one with ITS OWN label.
            if current_word:
                segments.append((current_word, current_label))
            current_word = token.replace(WORD_START, " ")
            # A word is labelled by the prediction on its first sub-word piece.
            current_label = id2label[label_id]
        else:
            # Continuation piece of the current word.
            current_word += token
    if current_word:
        segments.append((current_word, current_label))

    # The first word has no preceding text, so drop its leading space.
    if segments:
        first_word, first_label = segments[0]
        first_word = first_word.lstrip()
        if first_word:
            segments[0] = (first_word, first_label)
        else:
            segments.pop(0)
    return segments


def segment_text(input_text):
    """Label each word of ``input_text`` with a predicted segment type.

    Args:
        input_text: raw text from the UI textbox (may be empty).

    Returns:
        A list of ``(word, label)`` tuples suitable for ``gr.HighlightedText``.
        Returns an empty list for empty or whitespace-only input.
    """
    if not input_text or not input_text.strip():
        return []

    # truncation=True: tokens beyond the tokenizer's maximum length are dropped,
    # so very long inputs are only partially segmented.
    encoding = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        return_special_tokens_mask=True,
    )
    # The model's forward() does not accept this key; keep it for filtering.
    special_mask = encoding.pop("special_tokens_mask")[0].tolist()

    with torch.no_grad():
        logits = model(**encoding).logits

    # Index batch 0 rather than squeeze() so a 1-token sequence stays a list.
    predictions = logits.argmax(dim=-1)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
    return _merge_subwords(tokens, predictions, special_mask)


with gr.Blocks() as demo:
    gr.Markdown("# PleIAs/Segmentext Text Segmentation Demo")
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            lines=5,
            placeholder="Enter text for segmentation",
        )
        output_text = gr.HighlightedText(
            label="Segmented Text",
            color_map=color_map,
            combine_adjacent=True,
        )
    submit_button = gr.Button("Segment Text")
    submit_button.click(fn=segment_text, inputs=input_text, outputs=output_text)

if __name__ == "__main__":
    demo.launch()