import gradio as gr
from transformers import DebertaV2Tokenizer, DebertaV2ForTokenClassification
import torch
from globe import title, description, joinus, model_name, placeholder, modelinfor1, modelinfor2, modelinfor3, id2label

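# Load the PLeIAs segmentation tokenizer and model (model_name is defined in globe.py)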
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
model = DebertaV2ForTokenClassification.from_pretrained(model_name)

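# Highlight color assigned to each segment label in the gr.HighlightedText output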
color_map = {
    "author": "blue", "bibliography": "purple", "caption": "orange", 
    "contact": "cyan", "date": "green", "dialog": "yellow", 
    "footnote": "pink", "keywords": "lightblue", "math": "red", 
    "paratext": "lightgreen", "separator": "gray", "table": "brown", 
    "text": "lightgray", "title": "gold"
}


def segment_text(input_text):
    """Run token classification and merge subword pieces back into labelled words."""
    tokens = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = model(**tokens)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).squeeze().tolist()

    tokens_decoded = tokenizer.convert_ids_to_tokens(tokens['input_ids'].squeeze())

    segments = []
    current_word = ""
    current_label = None
    for token, label_id in zip(tokens_decoded, predictions):
        if token in tokenizer.all_special_tokens:  # skip [CLS], [SEP], [PAD]
            continue
        if token.startswith("▁"):  # "▁" marks the start of a new SentencePiece word
            if current_word:
                segments.append((current_word, current_label))
            current_word = token.replace("▁", "")  # start a new word
            current_label = id2label[str(label_id)]  # label the word by its first piece
        else:
            current_word += token  # append subword piece to the current word

    if current_word:
        segments.append((current_word, current_label))

    return segments

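# Gradio UI: title, description, collapsible model info, and the segmentation demo itself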
with gr.Blocks(theme=gr.themes.Base()) as demo:
    with gr.Row():
        gr.Markdown(title)
    with gr.Row():
        with gr.Group():                    
            gr.Markdown(description)
    with gr.Accordion(label="PLeIAs/✂️📜 Segment Text Model Information", open=False):
        with gr.Row():
            with gr.Group():
                gr.Markdown(modelinfor1)
            with gr.Group():
                gr.Markdown(modelinfor2)
            with gr.Group():
                gr.Markdown(modelinfor3)
    with gr.Accordion(label="Join Us", open=False):        
        gr.Markdown(joinus)    
    with gr.Row():
        input_text = gr.Textbox(label="Enter your text here👇🏻", lines=5, placeholder=placeholder)
        output_text = gr.HighlightedText(label="PLeIAs/✂️📜 Segment Text", color_map=color_map, combine_adjacent=True, show_inline_category=True, show_legend=True)
    submit_button = gr.Button("Segment Text")
    submit_button.click(fn=segment_text, inputs=input_text, outputs=output_text)

if __name__ == "__main__":
    demo.launch()