quyip committed on
Commit
02f8d21
·
1 Parent(s): 5949767
Files changed (1) hide show
  1. utils/summary_utils.py +26 -14
utils/summary_utils.py CHANGED
@@ -3,12 +3,11 @@ from transformers import pipeline
3
 
4
  from utils.tag_utils import filter_tags
5
 
6
- AiSummaryVersion = 3
7
- MinTagScore = 0.7
8
- summarization_pipeline = pipeline("summarization", model="Falconsai/text_summarization")
9
  en_translation_pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-mul-en")
10
- tag_gen_pipe_1 = pipeline("text-classification", model="dima806/news-category-classifier-distilbert")
11
- tag_gen_pipe_2 = pipeline("text-classification", model="elozano/bert-base-cased-news-category")
12
 
13
 
14
  def summarize(id: str, text: str):
@@ -16,10 +15,11 @@ def summarize(id: str, text: str):
16
  return {
17
  "ver": AiSummaryVersion
18
  }
19
- summary = get_summarization(text) if len(text) > 3000 else text
20
  translated = get_en_translation(summary)
21
- tags = get_tags(translated, id)
22
- tags = filter_tags(tags)
 
23
  tags = sorted(list(set(tags)))
24
 
25
  value = {
@@ -33,8 +33,7 @@ def summarize(id: str, text: str):
33
 
34
  def get_summarization(text: str):
35
  try:
36
- # Max / Min number of words
37
- result = summarization_pipeline(text, max_length=500, min_length=100, do_sample=False)
38
  return result[0]['summary_text'] if isinstance(result, list) else result['summary_text']
39
  except:
40
  return None
@@ -60,12 +59,25 @@ def is_english(text):
60
  return False
61
 
62
 
63
- def get_tags(text: str, id: str):
64
  if text is None:
65
  return []
66
  try:
67
- tags1 = [tag['label'] for tag in tag_gen_pipe_1(text) if tag['score'] >= MinTagScore]
68
- tags2 = [tag['label'] for tag in tag_gen_pipe_2(text) if tag['score'] >= MinTagScore]
69
- return tags1 + tags2
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  except:
71
  return []
 
3
 
4
  from utils.tag_utils import filter_tags
5
 
6
+ AiSummaryVersion = 4
7
+ summarization_pipeline = pipeline("summarization", model="csebuetnlp/mT5_multilingual_XLSum")
 
8
  en_translation_pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-mul-en")
9
+ classification_pipe = pipeline("text-classification", model="Yueh-Huan/news-category-classification-distilbert")
10
+ text_to_tags_pipe = pipeline('text2text-generation', model='models/text2tags')
11
 
12
 
13
  def summarize(id: str, text: str):
 
15
  return {
16
  "ver": AiSummaryVersion
17
  }
18
+ summary = get_summarization(text) if len(text) > 100 else text
19
  translated = get_en_translation(summary)
20
+ tags1 = get_classification(translated)
21
+ tags2 = get_tags(translated)
22
+ tags = filter_tags(tags1 + tags2)
23
  tags = sorted(list(set(tags)))
24
 
25
  value = {
 
33
 
34
  def get_summarization(text: str):
35
  try:
36
+ result = summarization_pipeline(text)
 
37
  return result[0]['summary_text'] if isinstance(result, list) else result['summary_text']
38
  except:
39
  return None
 
59
  return False
60
 
61
 
62
def get_tags(text: str):
    """Generate tags for ``text`` with the module-level ``text_to_tags_pipe``.

    The pipeline's ``generated_text`` is read as a comma-separated tag list.

    Args:
        text: The text to tag, or ``None``.

    Returns:
        The generated tags in order, each stripped of surrounding
        whitespace, with empty entries removed. An empty list when ``text``
        is ``None`` or tag generation fails.
    """
    import logging

    if text is None:
        return []
    try:
        result = text_to_tags_pipe(text)
        tag_str = result[0]['generated_text'] if isinstance(result, list) else result['generated_text']
        stripped = (tag.strip() for tag in tag_str.split(','))
        # "".split(',') yields [''], and trailing or doubled commas yield
        # blanks too; drop them so no empty-string tag is ever returned.
        return [tag for tag in stripped if tag]
    except Exception:
        # Best-effort: fall back to no tags, but log the cause and let
        # KeyboardInterrupt/SystemExit propagate.
        logging.getLogger(__name__).exception("Tag generation failed")
        return []
71
+
72
+
73
def get_classification(text: str, min_score: float = 0.75):
    """Classify ``text`` with the module-level ``classification_pipe``.

    Args:
        text: The text to classify, or ``None``.
        min_score: Exclusive lower bound on a prediction's ``score`` for its
            label to be kept. Defaults to 0.75.

    Returns:
        The whitespace-stripped labels of all predictions scoring strictly
        above ``min_score``. An empty list when ``text`` is ``None``, no
        prediction qualifies, or classification fails.
    """
    import logging

    if text is None:
        return []
    try:
        result = classification_pipe(text)
        # The pipeline may return one dict or a list of dicts; normalise to
        # a list so both shapes share one filtering path.
        predictions = result if isinstance(result, list) else [result]
        return [pred['label'].strip() for pred in predictions if pred['score'] > min_score]
    except Exception:
        # Best-effort: fall back to no labels, but log the cause and let
        # KeyboardInterrupt/SystemExit propagate.
        logging.getLogger(__name__).exception("Classification failed")
        return []