"""Gradio demo: label each word of a text with a DeBERTa-v2 token-classification model.

Predicted classes (author, title, footnote, ...) are shown with ``gr.HighlightedText``.
"""

import gradio as gr
import torch
from transformers import DebertaV2ForTokenClassification, DebertaV2Tokenizer

from globe import (
    description,
    id2label,
    joinus,
    model_name,
    modelinfor1,
    modelinfor2,
    modelinfor3,
    placeholder,
    title,
)

tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
model = DebertaV2ForTokenClassification.from_pretrained(model_name)

# Highlight colour per label name. Keys should match the label strings produced by
# ``id2label`` (those strings are the categories passed to HighlightedText).
color_map = {
    "author": "blue",
    "bibliography": "purple",
    "caption": "orange",
    "contact": "cyan",
    "date": "green",
    "dialog": "yellow",
    "footnote": "pink",
    "keywords": "lightblue",
    "math": "red",
    "paratext": "lightgreen",
    "separator": "gray",
    "table": "brown",
    "text": "lightgray",
    "title": "gold",
}

# SentencePiece prefix marking the first piece of a whitespace-separated word.
_WORD_START = "▁"


def segment_text(input_text):
    """Return ``(word, label)`` pairs for every word of *input_text*.

    Each word is labelled with the prediction for its first sub-word piece. Every word
    except the first carries a leading space, so concatenating the texts (as
    ``HighlightedText`` does when ``combine_adjacent=True``) keeps words separated.
    Tokenizer special tokens (CLS/SEP/PAD) are excluded from the output.

    Args:
        input_text: Raw text from the UI textbox.

    Returns:
        A list of ``(str, str)`` tuples; empty for blank input.

    Note:
        ``truncation=True`` drops everything beyond the tokenizer's maximum length,
        so only the beginning of a very long input is labelled.
    """
    if not input_text or not input_text.strip():
        return []

    encoding = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**encoding).logits

    # Batch of one: index row 0 rather than squeeze(), which would turn a
    # length-1 sequence into a bare scalar.
    label_ids = logits.argmax(dim=-1)[0].tolist()
    input_ids = encoding["input_ids"][0].tolist()
    pieces = tokenizer.convert_ids_to_tokens(input_ids)
    skip_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

    words = []  # list of [text, label_id]; label comes from the word's first piece
    for token_id, piece, label_id in zip(input_ids, pieces, label_ids):
        if token_id in skip_ids:
            continue
        starts_word = piece.startswith(_WORD_START)
        text = piece[len(_WORD_START):] if starts_word else piece
        if starts_word or not words:
            separator = " " if words else ""
            words.append([separator + text, label_id])
        else:
            words[-1][0] += text  # continuation piece: same word, keep its label

    return [(text, id2label[str(label_id)]) for text, label_id in words]


with gr.Blocks(theme=gr.themes.Base()) as demo:
    with gr.Row():
        gr.Markdown(title)
    with gr.Row():
        with gr.Group():
            gr.Markdown(description)
    with gr.Row():
        with gr.Group():
            gr.Markdown(modelinfor1)
        with gr.Group():
            gr.Markdown(modelinfor2)
        with gr.Group():
            gr.Markdown(modelinfor3)
    with gr.Accordion(label="Join Us", open=False):
        gr.Markdown(joinus)
    with gr.Row():
        input_text = gr.Textbox(
            label="Enter your text here👇🏻",
            lines=5,
            placeholder=placeholder,
        )
        output_text = gr.HighlightedText(
            label=" PLeIAs/✂️📜 Segment Text",
            color_map=color_map,
            combine_adjacent=True,
            show_inline_category=True,
            show_legend=True,
        )
    submit_button = gr.Button("Segment Text")
    submit_button.click(fn=segment_text, inputs=input_text, outputs=output_text)

if __name__ == "__main__":
    demo.launch()