sooolee committed on
Commit
6f91ca7
·
1 Parent(s): ea9c0cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -38,7 +38,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  peft_model_id = "sooolee/flan-t5-base-cnn-samsum-lora"
39
  config = PeftConfig.from_pretrained(peft_model_id)
40
  tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
41
- model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, device_map='auto') # load_in_8bit=True,
42
  model = PeftModel.from_pretrained(model, peft_model_id, device_map='auto')
43
 
44
  def summarize(video_id):
@@ -51,12 +51,12 @@ def summarize(video_id):
51
  transcript += dict[i]['text']
52
 
53
  texts = preprocessing(transcript)
54
- inputs = tokenizer(*texts, return_tensors="pt", padding=True, )
55
  inputs = inputs["input_ids"].to(device)
56
 
57
  with torch.no_grad():
58
- output_tokens = model.generate(*inputs, max_new_tokens=60, do_sample=True, top_p=0.9)
59
- outputs = tokenizer.batch_decode(output_tokens[0].detach().cpu().numpy(), skip_special_tokens=True)
60
 
61
  return outputs
62
 
 
38
  peft_model_id = "sooolee/flan-t5-base-cnn-samsum-lora"
39
  config = PeftConfig.from_pretrained(peft_model_id)
40
  tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
41
+ model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map='auto') # load_in_8bit=True,
42
  model = PeftModel.from_pretrained(model, peft_model_id, device_map='auto')
43
 
44
  def summarize(video_id):
 
51
  transcript += dict[i]['text']
52
 
53
  texts = preprocessing(transcript)
54
+ inputs = tokenizer(texts, return_tensors="pt", padding=True, )
55
  inputs = inputs["input_ids"].to(device)
56
 
57
  with torch.no_grad():
58
+ output_tokens = model.generate(inputs, max_new_tokens=60, do_sample=True, top_p=0.9)
59
+ outputs = tokenizer.batch_decode(output_tokens.detach().cpu().numpy(), skip_special_tokens=True)
60
 
61
  return outputs
62