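"""YouTube video summarization pipeline.

Fetches the English transcript of a YouTube video, restores punctuation,
translates the text to Vietnamese, splits it into overlapping chunks,
summarizes each chunk with a fine-tuned ViT5 model, and formats the result
as bullet points.
"""
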
import re
import time

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from youtube_transcript_api import YouTubeTranscriptApi
from deepmultilingualpunctuation import PunctuationModel
from googletrans import Translator



def load_model(cp):
    """Load the ViT5 tokenizer and a fine-tuned summarization checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained(cp)
    return tokenizer, model


def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
    """Generate an abstractive summary of `text` with beam search."""
    model.to(device)
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True).to(device)

    with torch.no_grad():
        summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    return summary


def processed(text):
    """Normalize text: replace newlines with spaces and lowercase."""
    processed_text = text.replace('\n', ' ')
    processed_text = processed_text.lower()
    return processed_text


def get_subtitles(video_url):
    """Fetch the English transcript for a YouTube URL.

    Returns the raw transcript entries and the joined subtitle text, or an
    empty list and an error message if the transcript cannot be retrieved.
    """
    try:
        # Extract the video id and drop any trailing query parameters.
        video_id = video_url.split("v=")[1].split("&")[0]
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
        subs = " ".join(entry['text'] for entry in transcript)
        return transcript, subs

    except Exception as e:
        return [], f"An error occurred: {e}"


def restore_punctuation(text):
    """Restore punctuation in raw transcript text.

    Note: the punctuation model is loaded on every call; cache it if this
    function is invoked repeatedly.
    """
    model = PunctuationModel()
    return model.restore_punctuation(text)


def translate_long(text, language='vi'):
    """Translate long text by splitting it into sentence-aligned chunks.

    googletrans rejects very long inputs, so the text is split at sentence
    boundaries into chunks under `limit` characters and translated one at a
    time; chunks that fail to translate are kept in the original language.
    """
    translator = Translator()
    limit = 4700
    chunks = []
    current_chunk = ''

    # Split on sentence-ending punctuation, avoiding common abbreviations.
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)

    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= limit:
            current_chunk += sentence.strip() + ' '
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence.strip() + ' '

    if current_chunk:
        chunks.append(current_chunk.strip())

    translated_text = ''

    for chunk in chunks:
        try:
            time.sleep(1)  # throttle requests to avoid rate limiting
            translation = translator.translate(chunk, dest=language)
            translated_text += translation.text + ' '
        except Exception:
            # Fall back to the untranslated chunk on failure.
            translated_text += chunk + ' '

    return translated_text.strip()

def split_into_chunks(text, max_words=800, overlap_sentences=2):
    """Split text into chunks of at most `max_words` words.

    Consecutive chunks overlap by `overlap_sentences` sentences so that
    context is not lost at chunk boundaries.
    """
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)

    chunks = []
    current_chunk = []
    current_word_count = 0

    for sentence in sentences:
        word_count = len(sentence.split())
        if current_word_count + word_count <= max_words:
            current_chunk.append(sentence)
            current_word_count += word_count
        else:
            # Close the current chunk and start a new one that repeats the
            # last `overlap_sentences` sentences for context.
            chunks.append(' '.join(current_chunk))
            current_chunk = current_chunk[-overlap_sentences:] + [sentence]
            current_word_count = sum(len(sent.split()) for sent in current_chunk)

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks


def post_processing(text):
    """Capitalize the first letter of every sentence."""
    sentences = re.split(r'(?<=[.!?])\s*', text)
    for i in range(len(sentences)):
        if sentences[i]:
            sentences[i] = sentences[i][0].upper() + sentences[i][1:]
    return " ".join(sentences)

def display(text):
    """Format the summary as deduplicated bullet points, one per sentence."""
    sentences = re.split(r'(?<=[.!?])\s*', text)
    # Drop the trailing empty split and remove duplicate sentences while
    # preserving their order.
    unique_sentences = list(dict.fromkeys(sentences[:-1]))
    return [f"• {sentence}" for sentence in unique_sentences]

def pipeline(url, model, tokenizer):
    """Run the full pipeline: transcript -> punctuation -> translation -> summary."""
    _, sub = get_subtitles(url)
    sub = restore_punctuation(sub)
    vie_sub = translate_long(sub)
    vie_sub = processed(vie_sub)
    chunks = split_into_chunks(vie_sub, 700, 2)
    sum_para = []
    for chunk in chunks:
        sum_para.append(summarize(chunk, model, tokenizer, num_beams=3))
    suma = ' '.join(sum_para)
    suma = post_processing(suma)
    return display(suma)
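

# Example usage (a minimal sketch: the checkpoint path and video URL below are
# placeholders for illustration, not values taken from the original project).
if __name__ == "__main__":
    tokenizer, model = load_model("path/to/finetuned-vit5-checkpoint")
    bullets = pipeline("https://www.youtube.com/watch?v=VIDEO_ID", model, tokenizer)
    for line in bullets:
        print(line)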