File size: 4,478 Bytes
65b1238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a146d7d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, TrainingArguments
from youtube_transcript_api import YouTubeTranscriptApi
from deepmultilingualpunctuation import PunctuationModel
from googletrans import Translator
import time
import torch
import re

def load_model(cp):
    """Load the tokenizer and the seq2seq model used for summarization.

    The tokenizer is always taken from the ``VietAI/vit5-base`` hub entry,
    while the model weights are loaded from ``cp``.

    Args:
        cp: Checkpoint identifier or path handed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is
    # loaded before the model.
    return (
        AutoTokenizer.from_pretrained("VietAI/vit5-base"),
        AutoModelForSeq2SeqLM.from_pretrained(cp),
    )


def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
    """Generate a summary of ``text`` with a seq2seq model.

    Args:
        text: Input text; it is encoded with truncation at 1024 tokens.
        model: Object exposing ``to(device)`` and ``generate(...)``.
        tokenizer: Object exposing ``encode(...)`` and ``decode(...)``.
        num_beams: Beam count forwarded to ``model.generate``.
        device: Device the model and the encoded input are moved to.

    Returns:
        The decoded first generated sequence, with special tokens skipped.
    """
    model.to(device)

    encoded = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(device)

    # Inference only: no gradients are needed while generating.
    with torch.no_grad():
        generated = model.generate(encoded, max_length=256, num_beams=num_beams)

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


def processed(text):
    """Normalize ``text``: newlines become spaces and everything is lowercased.

    Args:
        text: The string to normalize.

    Returns:
        The single-line, lowercase version of ``text``.
    """
    return text.replace('\n', ' ').lower()


def _extract_video_id(video_url):
    """Return the YouTube video id contained in ``video_url``.

    Handles ``watch?v=<id>`` URLs (with or without extra query parameters),
    ``youtu.be/<id>`` short links and ``/embed/<id>``, ``/shorts/<id>``,
    ``/live/<id>`` and ``/v/<id>`` paths.

    Raises:
        ValueError: If no video id can be located in the URL.
    """
    from urllib.parse import parse_qs, urlparse

    url = video_url.strip()
    parsed = urlparse(url)

    # watch?v=<id>&t=30s -> parse_qs isolates the id from the other parameters.
    query_ids = parse_qs(parsed.query).get("v")
    if query_ids:
        return query_ids[0]

    segments = [part for part in parsed.path.split("/") if part]
    host = (parsed.hostname or "").lower()
    if not host and segments and "." in segments[0]:
        # URL given without a scheme: urlparse puts the host into the path.
        host = segments.pop(0).lower()

    if (host == "youtu.be" or host.endswith(".youtu.be")) and segments:
        return segments[0]
    if len(segments) >= 2 and segments[0] in ("embed", "shorts", "live", "v"):
        return segments[1]

    # Fallback for strings that merely contain "v=<id>": take what follows
    # the marker, but stop at the next URL delimiter.
    if "v=" in url:
        candidate = url.split("v=")[1]
        for delimiter in "&#?":
            candidate = candidate.split(delimiter, 1)[0]
        if candidate:
            return candidate

    raise ValueError(f"could not find a video id in URL: {video_url!r}")


def get_subtitles(video_url):
    """Fetch the English transcript of a YouTube video.

    Args:
        video_url: URL of the video (watch, short-link, embed or shorts form).

    Returns:
        ``(transcript, subs)`` where ``transcript`` is the list of entries
        returned by ``YouTubeTranscriptApi.get_transcript`` and ``subs`` is
        the entries' ``'text'`` fields joined with single spaces.
        On any failure returns ``([], "An error occurred: <reason>")``;
        this function does not raise.
    """
    # The broad except is deliberate: callers rely on the
    # ([], error message) return value instead of an exception.
    try:
        video_id = _extract_video_id(video_url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
        subs = " ".join(entry['text'] for entry in transcript)
        print(subs)

        return transcript, subs

    except Exception as e:
        return [], f"An error occurred: {e}"

from youtube_transcript_api import YouTubeTranscriptApi


def restore_punctuation(text):
    """Run ``text`` through a punctuation-restoration model.

    A new ``PunctuationModel`` is constructed on every call.

    Args:
        text: The text handed to ``PunctuationModel.restore_punctuation``.

    Returns:
        Whatever ``PunctuationModel.restore_punctuation`` returns for ``text``.
    """
    return PunctuationModel().restore_punctuation(text)


def _pack_sentences(sentences, limit):
    """Greedily pack sentences into chunks of at most ``limit`` characters.

    Blank sentences are skipped, and a single sentence longer than ``limit``
    is broken up (on a space when possible) so no chunk exceeds the limit.
    ``limit`` must be at least 1.
    """
    chunks = []
    current = ''

    for sentence in sentences:
        sentence = sentence.strip()

        # Break up a sentence that could never fit into one request.
        while len(sentence) > limit:
            if current:
                chunks.append(current)
                current = ''
            cut = sentence.rfind(' ', 0, limit + 1)
            if cut <= 0:
                cut = limit  # no space available: hard cut
            chunks.append(sentence[:cut].strip())
            sentence = sentence[cut:].strip()

        if not sentence:
            continue

        candidate = f"{current} {sentence}" if current else sentence
        if len(candidate) <= limit:
            current = candidate
        else:
            chunks.append(current)
            current = sentence

    if current:
        chunks.append(current)
    return chunks


def translate_long(text, language='vi'):
    """Translate arbitrarily long ``text`` by sending it in sentence-aligned chunks.

    The text is split into sentences, the sentences are packed into chunks of
    at most 4700 characters, and each chunk is translated separately.
    Translation is best effort: a chunk whose translation fails is kept
    untranslated in the output (and a warning is logged).

    Args:
        text: The text to translate. Empty or whitespace-only text yields ``''``.
        language: Destination language code passed as ``dest`` to the translator.

    Returns:
        The translated chunks joined with single spaces.
    """
    import logging

    if not text or not text.strip():
        return ''

    translator = Translator()
    # NOTE(review): 4700 is presumably chosen to stay below the translation
    # backend's per-request size cap - confirm against the service limits.
    limit = 4700

    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    chunks = _pack_sentences(sentences, limit)

    translated_parts = []
    for index, chunk in enumerate(chunks):
        if index:
            # Pause between consecutive requests (presumably rate limiting).
            time.sleep(1)
        try:
            translated_parts.append(translator.translate(chunk, dest=language).text)
        except Exception:
            # Deliberate best effort: keep the source text for this chunk
            # rather than failing the whole translation.
            logging.getLogger(__name__).warning(
                "Translation failed for a chunk; keeping the original text.",
                exc_info=True,
            )
            translated_parts.append(chunk)

    return ' '.join(translated_parts).strip()

def split_into_chunks(text, max_words=800, overlap_sentences=2):
    """Split ``text`` into sentence-aligned chunks with overlapping context.

    Sentences are accumulated until adding the next one would exceed
    ``max_words``. The chunk is then emitted and the next chunk starts with
    the last ``overlap_sentences`` sentences of the previous one, followed by
    the sentence that did not fit. A single sentence longer than
    ``max_words`` is kept whole, so a chunk may exceed the budget.

    Args:
        text: The text to split.
        max_words: Word budget per chunk (whitespace-separated words).
        overlap_sentences: Number of trailing sentences repeated at the start
            of the next chunk; values <= 0 disable the overlap.

    Returns:
        A list of non-empty chunk strings; ``[]`` for empty text.
    """
    sentences = [
        sentence
        for sentence in re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
        if sentence.strip()
    ]
    overlap_sentences = max(overlap_sentences, 0)

    chunks = []
    current_chunk = []
    current_word_count = 0

    for sentence in sentences:
        word_count = len(sentence.split())
        if current_word_count + word_count <= max_words:
            current_chunk.append(sentence)
            current_word_count += word_count
            continue

        # An oversized first sentence arrives with nothing accumulated yet;
        # do not emit an empty chunk for it.
        if current_chunk:
            chunks.append(' '.join(current_chunk))

        # seq[-0:] is the WHOLE sequence, so zero overlap needs its own case.
        carried = current_chunk[-overlap_sentences:] if overlap_sentences else []
        current_chunk = carried + [sentence]
        current_word_count = sum(len(sent.split()) for sent in current_chunk)

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks


def post_processing(text):
    """Capitalize the first character of every sentence in ``text``.

    The text is split after each ``.``, ``!`` or ``?`` together with any
    whitespace that follows. Because that whitespace is optional, a split
    also happens right after a punctuation mark that has no space behind it.
    The pieces are re-joined with single spaces.

    Args:
        text: The text to re-capitalize.

    Returns:
        The re-joined text with each non-empty piece starting in upper case.
    """
    pieces = re.split(r'(?<=[.!?])\s*', text)
    # Only the first character is upper-cased; the rest of each piece is kept.
    capitalized = [
        piece[0].upper() + piece[1:] if piece else piece
        for piece in pieces
    ]
    return " ".join(capitalized)


def display(text):
    """Turn ``text`` into a de-duplicated list of bullet-point sentences.

    The text is split after ``.``, ``!`` or ``?`` (plus optional whitespace).
    The last piece of the split is always discarded, repeated sentences are
    dropped while keeping first-occurrence order, and each remaining sentence
    is prefixed with a bullet.

    Args:
        text: The text to format.

    Returns:
        A list of strings, one bullet line per unique sentence.
    """
    pieces = re.split(r'(?<=[.!?])\s*', text)
    pieces.pop()  # re.split always returns at least one element

    seen = set()
    bullets = []
    for sentence in pieces:
        if sentence in seen:
            continue
        seen.add(sentence)
        bullets.append(f"• {sentence}")
    return bullets



def pipeline(url, model, tokenizer):
    """Summarize a YouTube video as a list of bullet-point sentences.

    Steps: fetch the transcript, restore punctuation, translate, normalize,
    split into overlapping chunks (700 words, 2 sentences of overlap),
    summarize every chunk and format the combined summary.

    Args:
        url: URL of the YouTube video.
        model: Summarization model forwarded to ``summarize``.
        tokenizer: Tokenizer forwarded to ``summarize``.

    Returns:
        The list of bullet strings produced by ``display``. If no transcript
        could be fetched, a one-element list holding the message returned by
        ``get_subtitles`` (or ``[]`` if that message is empty).
    """
    transcript, subtitles = get_subtitles(url)
    if not transcript:
        # get_subtitles reports failure as ([], "An error occurred: ...").
        # Surface that message instead of punctuating, translating and
        # summarizing it as if it were subtitles.
        return [subtitles] if subtitles else []

    subtitles = restore_punctuation(subtitles)
    vie_sub = processed(translate_long(subtitles))
    chunks = split_into_chunks(vie_sub, 700, 2)

    # Blank chunks carry nothing to summarize.
    summaries = [
        summarize(chunk, model, tokenizer, num_beams=3)
        for chunk in chunks
        if chunk.strip()
    ]

    # Join with a space so a summary that does not end in punctuation is not
    # fused with the first word of the next one.
    combined = post_processing(' '.join(summaries))
    return display(combined)

def update(name):
    """Build the welcome message shown for ``name``.

    Args:
        name: The value interpolated into the greeting.

    Returns:
        The greeting string.
    """
    greeting = f"Welcome to Gradio, {name}!"
    return greeting