File size: 2,321 Bytes
9d032cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
from transformers import DebertaV2Tokenizer, DebertaV2ForTokenClassification
import torch


# Hugging Face model id for the PleIAs text-segmentation classifier.
model_name = "PleIAs/Segmentext"
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
model = DebertaV2ForTokenClassification.from_pretrained(model_name)

# Segment-type labels, ordered by the class index the model's head emits.
_LABELS = [
    "author", "bibliography", "caption", "contact", "date", "dialog",
    "footnote", "keywords", "math", "paratext", "separator", "table",
    "text", "title",
]

# Class index -> human-readable segment label.
id2label = dict(enumerate(_LABELS))

# Display colour for each label in the HighlightedText widget.
color_map = dict(zip(_LABELS, [
    "blue", "purple", "orange", "cyan", "green", "yellow",
    "pink", "lightblue", "red", "lightgreen", "gray", "brown",
    "lightgray", "gold",
]))


def segment_text(input_text):
    """Classify each word of *input_text* into a segment type.

    Runs the token-classification model over the (truncated) input and
    merges SentencePiece subword pieces back into whole words, labelling
    each word with the prediction for its *first* subword.

    Args:
        input_text: Raw text to segment.

    Returns:
        List of ``(word, label)`` tuples, where ``label`` is one of the
        ``id2label`` values (e.g. "text", "title", "footnote").
    """
    tokens = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = model(**tokens)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).squeeze().tolist()

    tokens_decoded = tokenizer.convert_ids_to_tokens(tokens['input_ids'].squeeze())

    # Special tokens ([CLS], [SEP], [PAD], ...) carry no word content and
    # must not be glued onto real words.
    special_tokens = set(tokenizer.all_special_tokens)

    segments = []
    current_word = ""
    current_label_id = None  # label of the current word's FIRST subword
    for token, label_id in zip(tokens_decoded, predictions):
        if token in special_tokens:
            continue
        if token.startswith("▁"):  # SentencePiece word-start marker
            # Flush the previous word with ITS OWN label (not this token's):
            # using `label_id` here would shift every label by one word.
            if current_word:
                segments.append((current_word, id2label[current_label_id]))
            current_word = token.replace("▁", "")
            current_label_id = label_id
        else:
            current_word += token  # continuation subword of the same word
            if current_label_id is None:
                # Rare: first content token lacks the word-start marker.
                current_label_id = label_id

    if current_word:
        segments.append((current_word, id2label[current_label_id]))

    return segments

# Build the demo UI; creation order inside the context defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# PleIAs/Segmentext Text Segmentation Demo")

    with gr.Row():
        source_box = gr.Textbox(label="Input Text", lines=5, placeholder="Enter text for segmentation")
        result_view = gr.HighlightedText(label="Segmented Text", color_map=color_map, combine_adjacent=True)

    segment_btn = gr.Button("Segment Text")
    # Wire the model function straight to the click event — it already has
    # the (text) -> [(word, label), ...] signature HighlightedText expects.
    segment_btn.click(fn=segment_text, inputs=source_box, outputs=result_view)

if __name__ == "__main__":
    demo.launch()