File size: 4,601 Bytes
be29c20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, TrainingArguments
from youtube_transcript_api import YouTubeTranscriptApi
from deepmultilingualpunctuation import PunctuationModel
from googletrans import Translator
import time
import torch
import re

# NOTE(review): appears to be a workaround for a googletrans/httpcore
# incompatibility; left disabled — re-enable only if googletrans import fails.
# import httpcore
# setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')


# Hugging Face Hub checkpoint id of the fine-tuned (augmented-data) ViT5 summarizer.
cp_aug = 'minnehwg/finetune-newwiki-summarization-ver-augmented2'

def load_model(cp):
    """Load the ViT5-base tokenizer plus a fine-tuned seq2seq model.

    The tokenizer always comes from the base "VietAI/vit5-base" vocab, while
    the model weights are pulled from the checkpoint *cp*.
    """
    tok = AutoTokenizer.from_pretrained("VietAI/vit5-base")
    net = AutoModelForSeq2SeqLM.from_pretrained(cp)
    return tok, net


def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
    """Generate an abstractive summary of *text* with the given seq2seq model.

    Input is truncated to 1024 tokens; the summary is capped at 256 tokens
    and decoded with beam search (`num_beams`).
    """
    model.to(device)
    encoded = tokenizer.encode(
        text, return_tensors="pt", max_length=1024, truncation=True, padding=True
    ).to(device)

    # Inference only — skip gradient bookkeeping.
    with torch.no_grad():
        generated = model.generate(encoded, max_length=256, num_beams=num_beams)

    return tokenizer.decode(generated[0], skip_special_tokens=True)


def processed(text):
    """Normalize raw text: collapse newlines into spaces and lowercase it."""
    return text.replace('\n', ' ').lower()


def get_subtitles(video_url):
    """Fetch the English transcript for a YouTube watch URL.

    Returns:
        (transcript, subs): the raw transcript entry list and all entry texts
        joined into one space-separated string. On any failure, returns
        ([], "An error occurred: ...") instead of raising.
    """
    try:
        # BUG FIX: the old parse kept everything after "v=", so URLs with
        # additional query params (e.g. "...v=ID&t=10s") produced an invalid
        # video id. Cut at the next "&" as well.
        video_id = video_url.split("v=")[1].split("&")[0]
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
        subs = " ".join(entry['text'] for entry in transcript)

        return transcript, subs

    except Exception as e:
        # Best-effort contract: callers receive an empty transcript plus a
        # human-readable error message in place of the subtitle string.
        return [], f"An error occurred: {e}"


def restore_punctuation(text):
    """Re-insert punctuation into *text* with a pretrained punctuation model.

    NOTE: a fresh PunctuationModel is instantiated on every call; acceptable
    for one-shot pipeline use.
    """
    punctuator = PunctuationModel()
    return punctuator.restore_punctuation(text)


def translate_long(text, language='vi'):
    """Translate *text* chunk-by-chunk (default target: Vietnamese).

    The text is split on sentence boundaries and greedily packed into chunks
    below the googletrans request size limit; chunks that fail to translate
    are kept untranslated (best-effort).
    """
    translator = Translator()
    limit = 4700  # stay under the ~5000-char googletrans request cap

    # Sentence split; the lookbehinds avoid breaking on abbreviations
    # like "e.g." or "Mr.".
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)

    # Greedily pack sentences into chunks of at most `limit` characters.
    chunks = []
    buffer = ''
    for sentence in sentences:
        if len(buffer) + len(sentence) <= limit:
            buffer += sentence.strip() + ' '
        else:
            chunks.append(buffer.strip())
            buffer = sentence.strip() + ' '
    if buffer:
        chunks.append(buffer.strip())

    pieces = []
    for chunk in chunks:
        try:
            time.sleep(1)  # throttle to dodge rate limiting
            pieces.append(translator.translate(chunk, dest=language).text)
        except Exception:
            # Best-effort: fall back to the untranslated chunk.
            pieces.append(chunk)

    return (' '.join(pieces) + ' ').strip()

def split_into_chunks(text, max_words=800, overlap_sentences=2):
    """Split *text* into word-capped chunks that overlap by whole sentences.

    Args:
        text: input text, split on sentence boundaries (abbreviation-aware).
        max_words: soft cap on the number of words per chunk.
        overlap_sentences: how many trailing sentences of one chunk are
            repeated at the start of the next, to preserve context.

    Returns:
        List of chunk strings.
    """
    # BUG FIX: removed leftover debug print() calls that dumped the overlap
    # sentences to stdout on every chunk boundary.
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)

    chunks = []
    current_chunk = []
    current_word_count = 0

    for sentence in sentences:
        word_count = len(sentence.split())
        if current_word_count + word_count <= max_words:
            current_chunk.append(sentence)
            current_word_count += word_count
        else:
            # Flush the full chunk, then seed the next one with the last
            # `overlap_sentences` sentences plus the sentence that overflowed.
            chunks.append(' '.join(current_chunk))
            current_chunk = current_chunk[-overlap_sentences:] + [sentence]
            current_word_count = sum(len(sent.split()) for sent in current_chunk)
    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks


def post_processing(text):
    """Capitalize the first character of every sentence in *text*.

    Sentences are re-joined with single spaces; a trailing empty piece from
    the split leaves one trailing space in the result.
    """
    parts = re.split(r'(?<=[.!?])\s*', text)
    fixed = [p[0].upper() + p[1:] if p else p for p in parts]
    return " ".join(fixed)

def display(text):
    """Return the unique sentences of *text* as bullet-point strings.

    The final piece of the sentence split (the empty trailing fragment) is
    dropped; first-appearance order is preserved when de-duplicating.
    """
    parts = re.split(r'(?<=[.!?])\s*', text)
    unique = dict.fromkeys(parts[:-1])
    return [f"• {sentence}" for sentence in unique]

def pipeline(url):
    """End-to-end flow: fetch YouTube subtitles, restore punctuation,
    translate to Vietnamese, chunk, summarize each chunk, and return the
    summary as bullet-point sentences.

    NOTE(review): relies on module-level `model_aug` and `tokenizer` that are
    not defined in this file — presumably loaded elsewhere via
    `load_model(cp_aug)`; confirm before running.
    """
    # BUG FIX: the old code bound locals named `sum` and `re`, shadowing the
    # builtin and the regex module inside this function; renamed them.
    _, subs = get_subtitles(url)
    subs = restore_punctuation(subs)
    vie_sub = processed(translate_long(subs))
    chunks = split_into_chunks(vie_sub, 700, 3)
    # Summarize each chunk independently, then concatenate the partial summaries.
    summary = ''.join(
        summarize(chunk, model_aug, tokenizer, num_beams=4) for chunk in chunks
    )
    summary = post_processing(summary)
    return display(summary)