MarMont committed on
Commit
7406912
1 Parent(s): 9da4768

fix transformers

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -32,6 +32,8 @@ from operator import itemgetter
32
 
33
  import gradio as gr
34
 
 
 
35
  global df
36
  bearer_token = 'AAAAAAAAAAAAAAAAAAAAACEigwEAAAAACoP8KHJYLOKCL4OyB9LEPV00VB0%3DmyeDROUvw4uipHwvbPPfnTuY0M9ORrLuXrMvcByqZhwo3SUc4F'
37
  client = tweepy.Client(bearer_token=bearer_token)
@@ -392,11 +394,10 @@ def full_lda(df):
392
  return top_tweets
393
 
394
  def topic_summarization(topic_groups):
395
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
396
 
397
- model = T5ForConditionalGeneration.from_pretrained("Michau/t5-base-en-generate-headline")
398
- tokenizer = T5Tokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
399
- model = model.to(device)
400
  translator = Translator()
401
 
402
  headlines = []
@@ -410,8 +411,8 @@ def topic_summarization(topic_groups):
410
  max_len = 256
411
 
412
  encoding = tokenizer.encode_plus(text, return_tensors = "pt")
413
- input_ids = encoding["input_ids"].to(device)
414
- attention_masks = encoding["attention_mask"].to(device)
415
 
416
  beam_outputs = model.generate(
417
  input_ids = input_ids,
 
32
 
33
  import gradio as gr
34
 
35
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
36
+
37
  global df
38
  bearer_token = 'AAAAAAAAAAAAAAAAAAAAACEigwEAAAAACoP8KHJYLOKCL4OyB9LEPV00VB0%3DmyeDROUvw4uipHwvbPPfnTuY0M9ORrLuXrMvcByqZhwo3SUc4F'
39
  client = tweepy.Client(bearer_token=bearer_token)
 
394
  return top_tweets
395
 
396
  def topic_summarization(topic_groups):
397
+
398
 
399
+ tokenizer = AutoTokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
400
+ model = AutoModelForSeq2SeqLM.from_pretrained("Michau/t5-base-en-generate-headline")
 
401
  translator = Translator()
402
 
403
  headlines = []
 
411
  max_len = 256
412
 
413
  encoding = tokenizer.encode_plus(text, return_tensors = "pt")
414
+ input_ids = encoding["input_ids"]
415
+ attention_masks = encoding["attention_mask"]
416
 
417
  beam_outputs = model.generate(
418
  input_ids = input_ids,