"""Extractive question answering with BigBird on a SQuAD v2 example.

Loads a pretrained BigBird QA model, runs it on one (question, context)
pair from the SQuAD v2 training split, and decodes the predicted answer
span from the tokenized input.

NOTE: downloads the model weights and dataset on first run (network I/O).
"""
import torch
from transformers import AutoTokenizer, BigBirdForQuestionAnswering
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base")

squad_ds = load_dataset("squad_v2", split="train")

# Pick one fixed article/question pair from the training split.
LONG_ARTICLE = squad_ds[81514]["context"]
QUESTION = squad_ds[81514]["question"]

# Tokenize question + long context together as a single sequence pair.
inputs = tokenizer(QUESTION, LONG_ARTICLE, return_tensors="pt")

# Inference only — no gradient bookkeeping needed.
with torch.no_grad():
    outputs = model(**inputs)

# Greedy span selection: independent argmax over start and end logits.
# NOTE(review): if end < start this yields an empty span — the model does
# not enforce start <= end; acceptable for a demo, verify for production.
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()

# Slice the predicted token span (end index inclusive) and decode to text.
predict_answer_token_ids = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
predict_answer_token = tokenizer.decode(predict_answer_token_ids)
print(predict_answer_token)