minnehwg committed
Commit dcd9c9b · verified · 1 Parent(s): c510d8a

Delete util.py

Files changed (1):
util.py +0 -143
util.py DELETED
@@ -1,143 +0,0 @@
-from datasets import Dataset
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, TrainingArguments
-from youtube_transcript_api import YouTubeTranscriptApi
-from deepmultilingualpunctuation import PunctuationModel
-from googletrans import Translator
-import time
-import torch
-import re
-
-# import httpcore
-# setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
-
-
-cp_aug = 'minnehwg/finetune-newwiki-summarization-ver-augmented2'
-
-def load_model(cp):
-    tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
-    model = AutoModelForSeq2SeqLM.from_pretrained(cp)
-    return tokenizer, model
-
-
-def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
-    model.to(device)
-    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True, padding = True).to(device)
-
-    with torch.no_grad():
-        summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
-        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-    return summary
-
-
-def processed(text):
-    processed_text = text.replace('\n', ' ')
-    processed_text = processed_text.lower()
-    return processed_text
-
-
-def get_subtitles(video_url):
-    try:
-        video_id = video_url.split("v=")[1]
-        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
-        subs = " ".join(entry['text'] for entry in transcript)
-
-        return transcript, subs
-
-    except Exception as e:
-        return [], f"An error occurred: {e}"
-
-
-def restore_punctuation(text):
-    model = PunctuationModel()
-    result = model.restore_punctuation(text)
-    return result
-
-
-def translate_long(text, language='vi'):
-    translator = Translator()
-    limit = 4700
-    chunks = []
-    current_chunk = ''
-
-    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-
-    for sentence in sentences:
-        if len(current_chunk) + len(sentence) <= limit:
-            current_chunk += sentence.strip() + ' '
-        else:
-            chunks.append(current_chunk.strip())
-            current_chunk = sentence.strip() + ' '
-
-    if current_chunk:
-        chunks.append(current_chunk.strip())
-
-    translated_text = ''
-
-    for chunk in chunks:
-        try:
-            time.sleep(1)
-            translation = translator.translate(chunk, dest=language)
-            translated_text += translation.text + ' '
-        except Exception as e:
-            translated_text += chunk + ' '
-
-    return translated_text.strip()
-
-def split_into_chunks(text, max_words=800, overlap_sentences=2):
-    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-
-    chunks = []
-    current_chunk = []
-    current_word_count = 0
-
-    for sentence in sentences:
-        word_count = len(sentence.split())
-        if current_word_count + word_count <= max_words:
-            current_chunk.append(sentence)
-            current_word_count += word_count
-        else:
-            if len(current_chunk) >= overlap_sentences:
-                overlap = current_chunk[-overlap_sentences:]
-                print(f"Overlapping sentences: {' '.join(overlap)}")
-            chunks.append(' '.join(current_chunk))
-            current_chunk = current_chunk[-overlap_sentences:] + [sentence]
-            current_word_count = sum(len(sent.split()) for sent in current_chunk)
-    if current_chunk:
-        if len(current_chunk) >= overlap_sentences:
-            overlap = current_chunk[-overlap_sentences:]
-            print(f"Overlapping sentences: {' '.join(overlap)}")
-        chunks.append(' '.join(current_chunk))
-
-    return chunks
-
-
-def post_processing(text):
-    sentences = re.split(r'(?<=[.!?])\s*', text)
-    for i in range(len(sentences)):
-        if sentences[i]:
-            sentences[i] = sentences[i][0].upper() + sentences[i][1:]
-    text = " ".join(sentences)
-    return text
-
-def display(text):
-    sentences = re.split(r'(?<=[.!?])\s*', text)
-    unique_sentences = list(dict.fromkeys(sentences[:-1]))
-    formatted_sentences = [f"• {sentence}" for sentence in unique_sentences]
-    return formatted_sentences
-
-def pipeline(url):
-    trans, sub = get_subtitles(url)
-    sub = restore_punctuation(sub)
-    vie_sub = translate_long(sub)
-    vie_sub = processed(vie_sub)
-    chunks = split_into_chunks(vie_sub, 700, 3)
-    sum_para = []
-    for i in chunks:
-        tmp = summarize(i, model_aug, tokenizer, num_beams=4)
-        sum_para.append(tmp)
-    sum = ''.join(sum_para)
-    del sub, vie_sub, sum_para, chunks
-    sum = post_processing(sum)
-    re = display(sum)
-    return re
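
Note that pipeline() reads the module-level names model_aug and tokenizer, which the deleted file never defines, so calling it as-is raises a NameError. A minimal usage sketch, assuming the file above is restored locally as util.py and the missing globals are filled in via load_model(cp_aug); the video URL is a placeholder:

import util  # assumption: the deleted file above, saved locally as util.py

# pipeline() looks up `tokenizer` and `model_aug` in util's module globals,
# so patch them onto the module before calling it.
util.tokenizer, util.model_aug = util.load_model(util.cp_aug)

# Placeholder URL; get_subtitles() expects a "watch?v=<id>" style link.
bullets = util.pipeline("https://www.youtube.com/watch?v=<VIDEO_ID>")
for line in bullets:
    print(line)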