minnehwg committed
Commit f9f0b1e · verified · 1 Parent(s): 2878127

Create util.py

Files changed (1): util.py +2 -142
util.py CHANGED
@@ -1,142 +1,2 @@
- from datasets import Dataset
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, TrainingArguments
- from youtube_transcript_api import YouTubeTranscriptApi
- from deepmultilingualpunctuation import PunctuationModel
- from googletrans import Translator
- import time
- import torch
- import re
-
-
- def load_model(cp):
-     tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
-     model = AutoModelForSeq2SeqLM.from_pretrained(cp)
-     return tokenizer, model
-
-
- def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
-     model.to(device)
-     inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True, padding=True).to(device)
-
-     with torch.no_grad():
-         summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
-     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-     return summary
-
-
- def processed(text):
-     processed_text = text.replace('\n', ' ')
-     processed_text = processed_text.lower()
-     return processed_text
-
-
- def get_subtitles(video_url):
-     try:
-         video_id = video_url.split("v=")[1]
-         transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
-         subs = " ".join(entry['text'] for entry in transcript)
-         print(subs)
-
-         return transcript, subs
-
-     except Exception as e:
-         return [], f"An error occurred: {e}"
-
- from youtube_transcript_api import YouTubeTranscriptApi
-
-
- def restore_punctuation(text):
-     model = PunctuationModel()
-     result = model.restore_punctuation(text)
-     return result
-
-
- def translate_long(text, language='vi'):
-     translator = Translator()
-     limit = 4700
-     chunks = []
-     current_chunk = ''
-
-     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-
-     for sentence in sentences:
-         if len(current_chunk) + len(sentence) <= limit:
-             current_chunk += sentence.strip() + ' '
-         else:
-             chunks.append(current_chunk.strip())
-             current_chunk = sentence.strip() + ' '
-
-     if current_chunk:
-         chunks.append(current_chunk.strip())
-
-     translated_text = ''
-
-     for chunk in chunks:
-         try:
-             time.sleep(1)
-             translation = translator.translate(chunk, dest=language)
-             translated_text += translation.text + ' '
-         except Exception as e:
-             translated_text += chunk + ' '
-
-     return translated_text.strip()
-
- def split_into_chunks(text, max_words=800, overlap_sentences=2):
-     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-
-     chunks = []
-     current_chunk = []
-     current_word_count = 0
-
-     for sentence in sentences:
-         word_count = len(sentence.split())
-         if current_word_count + word_count <= max_words:
-             current_chunk.append(sentence)
-             current_word_count += word_count
-         else:
-             if len(current_chunk) >= overlap_sentences:
-                 overlap = current_chunk[-overlap_sentences:]
-             chunks.append(' '.join(current_chunk))
-             current_chunk = current_chunk[-overlap_sentences:] + [sentence]
-             current_word_count = sum(len(sent.split()) for sent in current_chunk)
-     if current_chunk:
-         if len(current_chunk) >= overlap_sentences:
-             overlap = current_chunk[-overlap_sentences:]
-         chunks.append(' '.join(current_chunk))
-
-     return chunks
-
-
- def post_processing(text):
-     sentences = re.split(r'(?<=[.!?])\s*', text)
-     for i in range(len(sentences)):
-         if sentences[i]:
-             sentences[i] = sentences[i][0].upper() + sentences[i][1:]
-     text = " ".join(sentences)
-     return text
-
-
- def display(text):
-     sentences = re.split(r'(?<=[.!?])\s*', text)
-     unique_sentences = list(dict.fromkeys(sentences[:-1]))
-     formatted_sentences = [f"• {sentence}" for sentence in unique_sentences]
-     return formatted_sentences
-
-
-
- def pipeline(url, model, tokenizer):
-     trans, sub = get_subtitles(url)
-     sub = restore_punctuation(sub)
-     vie_sub = translate_long(sub)
-     vie_sub = processed(vie_sub)
-     chunks = split_into_chunks(vie_sub, 700, 2)
-     sum_para = []
-     for i in chunks:
-         tmp = summarize(i, model, tokenizer, num_beams=3)
-         sum_para.append(tmp)
-     suma = ''.join(sum_para)
-     del sub, vie_sub, sum_para, chunks
-     suma = post_processing(suma)
-     re = display(suma)
-     return re
+ def update(name):
+     return f"Welcome to Gradio, {name}!"
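For context, the 142 removed lines implemented a YouTube video summarization pipeline: fetch English subtitles, restore punctuation, machine-translate the transcript to Vietnamese, split it into overlapping chunks, and summarize each chunk with a ViT5 seq2seq model. A minimal usage sketch of that removed API (the checkpoint path and video URL below are placeholders, not values taken from this repo):

    # Usage sketch for the removed utilities (hypothetical values).
    # load_model() pairs the VietAI/vit5-base tokenizer with a fine-tuned checkpoint.
    tokenizer, model = load_model("path/to/fine-tuned-vit5-checkpoint")  # placeholder path

    # pipeline() returns the summary as bullet-formatted sentences.
    bullets = pipeline("https://www.youtube.com/watch?v=VIDEO_ID", model, tokenizer)  # placeholder URL
    for line in bullets:
        print(line)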
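The replacement file keeps only a single Gradio demo handler. A minimal sketch of how update might be wired into an app (no app file accompanies this commit, so the wiring below is an assumption):

    import gradio as gr

    from util import update

    # Sketch only: assumes update() serves as an Interface callback;
    # the app file itself is not part of the commit shown above.
    demo = gr.Interface(fn=update, inputs="text", outputs="text")

    if __name__ == "__main__":
        demo.launch()

With this wiring, a name typed into the textbox is echoed back through update()'s greeting.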