File size: 1,440 Bytes
04cba07
7fbfa44
 
04cba07
884e1bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fbfa44
 
 
 
884e1bf
 
 
 
 
036fae7
884e1bf
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import nltk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Sentence-splitter data used by nltk.tokenize.sent_tokenize in fragment_text.
# Newer NLTK releases (3.8.2+) load the "punkt_tab" resource instead of the
# pickled "punkt" one, so fetch both to work across NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

def fragment_text(text, tokenizer, sentence_splitter=None):
    """Split *text* into sentence-aligned chunks sized for the model input.

    Sentences are packed greedily, in order, into chunks whose summed token
    count (as measured by ``tokenizer.tokenize``) stays within
    ``tokenizer.max_len_single_sentence``. A single sentence that alone
    exceeds the limit is emitted as its own (oversized) chunk, so callers
    should still truncate when encoding.

    Args:
        text: Input text to split.
        tokenizer: Hugging Face tokenizer providing ``tokenize`` and
            ``max_len_single_sentence``.
        sentence_splitter: Optional callable mapping a string to a list of
            sentences. Defaults to ``nltk.tokenize.sent_tokenize``.

    Returns:
        List of chunk strings, in original order; an empty list when the
        splitter yields no sentences.
    """
    if sentence_splitter is None:
        sentence_splitter = nltk.tokenize.sent_tokenize
    sentences = sentence_splitter(text)
    max_len = tokenizer.max_len_single_sentence

    chunks = []
    current = []      # sentences in the chunk being built
    current_len = 0   # token count of `current` (tokens, not characters)

    for sentence in sentences:
        sentence_len = len(tokenizer.tokenize(sentence))
        # Flush before overflowing. Never flush an empty chunk, so an
        # oversized leading sentence does not produce a "" chunk.
        if current and current_len + sentence_len > max_len:
            chunks.append(" ".join(current).strip())
            current, current_len = [], 0
        current.append(sentence)
        current_len += sentence_len

    if current:
        chunks.append(" ".join(current).strip())

    return chunks


def summarize_text(text, tokenizer, model):
    """Summarize *text* chunk by chunk and join the partial summaries.

    The text is split with ``fragment_text`` so each piece fits the model's
    input size; each chunk is summarized independently with
    ``model.generate`` and the decoded summaries are joined with spaces.

    Args:
        text: Input text to summarize.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Seq2seq model exposing ``generate``.

    Returns:
        The concatenated summary string ("" when there is nothing to
        summarize).
    """
    summaries = []
    for chunk in fragment_text(text, tokenizer):
        if not chunk:
            # Nothing to summarize; avoid feeding an empty prompt to generate.
            continue
        # truncation=True guards against a single sentence longer than the
        # model limit, which fragment_text emits as an oversized chunk.
        encoded = tokenizer(chunk, return_tensors='pt', truncation=True)
        generated = model.generate(**encoded)
        # generate returns one row of token ids per input sequence; exactly
        # one sequence is encoded here, so take the first row.
        summaries.append(tokenizer.decode(generated[0], skip_special_tokens=True))

    return " ".join(summaries)

# Load the pre-trained summarization checkpoint and its tokenizer once, at
# import time, so every Gradio request reuses them.
checkpoint = "tclopess/bart_samsum"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)


def summarize(text):
    """Gradio callback: summarize *text* with the module-level model."""
    return summarize_text(text, tokenizer, model)


# Define Gradio Interface.
# Gradio calls `fn` with one argument per input component; there is a single
# Textbox, so the callback must take only the text (summarize_text itself
# also needs the tokenizer and model). live=True re-runs on input changes.
iface = gr.Interface(
    fn=summarize,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
    live=True,
)

# Launch the Gradio Interface
iface.launch()