File size: 2,200 Bytes
fa77629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import functools
import time

import torch
from transformers import AutoTokenizer, GenerationConfig, TextStreamer, AutoModelForSeq2SeqLM
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub id of the seq2seq summarization model that gets loaded.
# The fine-tuned "Mia2024/CS5100TextSummarization" checkpoint used to be
# assigned here and was immediately overwritten (dead code); swap it in
# explicitly if that model is wanted instead.
checkpoint = "facebook/bart-large-cnn"

# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class StreamlitTextStreamer(TextStreamer):
    """Stream generated text into a Streamlit container as it is decoded.

    Extends ``transformers.TextStreamer``: instead of printing finalized text,
    it accumulates it in ``self.text`` and re-renders the whole running text
    as a level-6 Markdown heading (``"###### ..."``) in ``st_container``.

    Args:
        tokenizer: Tokenizer that ``TextStreamer`` uses to decode token ids.
        st_container: Object exposing ``markdown(str)`` (e.g. a Streamlit
            placeholder); it receives the running text on every update.
        st_info_container: Stored on the instance; this class never writes
            to it (presumably reserved for stats display by callers).
        skip_prompt: Forwarded to ``TextStreamer``.
        **decode_kwargs: Forwarded to ``TextStreamer`` and on to the
            tokenizer's decode call (e.g. ``skip_special_tokens=True``).

    Attributes:
        text: All finalized text received so far.
        start_time: ``time.time()`` at the first ``put`` call, i.e. when
            ``generate`` first hands ids to the streamer.
        first_token_time: ``time.time()`` when the first non-blank chunk
            was finalized.
        total_tokens: Running count of whitespace-separated words in the
            finalized chunks -- a word count, not a tokenizer-token count.
    """

    def __init__(self, tokenizer, st_container, st_info_container, skip_prompt=False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.st_container = st_container
        self.st_info_container = st_info_container
        self.text = ""
        self.start_time = None
        self.first_token_time = None
        self.total_tokens = 0

    def put(self, value):
        """Record when streaming starts, then defer to ``TextStreamer.put``."""
        # on_finalized_text only fires once decoded text is finalized, so
        # stamping the start there made start_time ~= first_token_time.
        if self.start_time is None:
            self.start_time = time.time()
        super().put(value)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append a finalized chunk and re-render the running text.

        Args:
            text: Newly finalized text (may be empty).
            stream_end: True on the final flush at the end of generation.
        """
        # Fallback in case put() was bypassed by the caller.
        if self.start_time is None:
            self.start_time = time.time()

        # The base streamer may report empty chunks while a word is still
        # incomplete; re-rendering and sleeping for those only stalls the
        # (synchronous) generation loop, so skip them unless this is the end.
        if not text and not stream_end:
            return

        if self.first_token_time is None and text.strip():
            self.first_token_time = time.time()

        self.text += text
        self.total_tokens += len(text.split())
        self.st_container.markdown("###### " + self.text)

        # Brief pause so the UI visibly "types" the summary; pointless after
        # the final chunk.
        if not stream_end:
            time.sleep(0.03)


@functools.lru_cache(maxsize=2)
def _load_tokenizer_and_model(model_name):
    """Load the tokenizer and seq2seq model for *model_name* once, on ``device``."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    return tokenizer, model


def _build_prompt(input_text):
    """Wrap *input_text* in the summarization prompt with a length hint."""
    # Ask for roughly 15% of the input's word count, but at least one word.
    target_length = max(1, int(0.15 * len(input_text.split())))
    return (
        "Summarize the following conversation: \n###\n"
        + input_text
        # Leading newline: previously the hint was glued onto the last word
        # of the input text.
        + f"\nThe generated summary should be around {target_length} words."
        + "\n### Summary:"
    )


def generate_summary(input_text, st_container, st_info_container) -> str:
    """Stream a sampled summary of *input_text* into Streamlit and return it.

    The tokenizer/model for the module-level ``checkpoint`` are loaded once
    and reused across calls.

    Args:
        input_text: Text to summarize.
        st_container: Receives the running summary via ``markdown``.
        st_info_container: Passed through to ``StreamlitTextStreamer``.

    Returns:
        The full summary text as streamed (special tokens skipped).
    """
    generation_config = GenerationConfig(
        # Set here (not only as a generate() kwarg) so the sampling
        # parameters below validate without warnings.
        do_sample=True,
        min_new_tokens=10,
        max_new_tokens=256,
        temperature=0.9,
        top_p=1.0,
        top_k=50,
    )
    tokenizer, model = _load_tokenizer_and_model(checkpoint)

    # truncation=True caps the prompt at the tokenizer's model_max_length so
    # long inputs do not overflow the model's positional embeddings (this can
    # cut off the tail of the prompt for very long inputs).
    # Inputs must live on the same device as the model.
    inputs = tokenizer(
        _build_prompt(input_text), return_tensors="pt", truncation=True
    ).to(device)

    # decoder_start_token_id is a generation setting, not a decode option;
    # passing it to the streamer's decode kwargs had no effect, so it is gone.
    streamer = StreamlitTextStreamer(
        tokenizer, st_container, st_info_container, skip_special_tokens=True
    )

    model.generate(**inputs, streamer=streamer, generation_config=generation_config)
    return streamer.text