File size: 1,043 Bytes
55b038b
e4f39c4
f2f4fc6
 
 
 
 
55dc8b1
e4f39c4
 
 
 
 
8d04b0f
e4f39c4
 
 
 
 
 
 
 
 
55dc8b1
e4f39c4
 
 
55dc8b1
e4f39c4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import string

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Text2TextGenerationPipeline,
)


class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
    """Text2text pipeline that returns the generated text as a list of keyphrases.

    The generated string is split on ``keyphrase_sep_token``. Each piece has its
    ASCII punctuation removed and surrounding whitespace stripped, and pieces
    that end up empty are dropped.
    """

    # Deletes every character in ``string.punctuation``; built once per class
    # instead of once per keyphrase.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
        """Load the seq2seq model and tokenizer named ``model`` and build the pipeline.

        Args:
            model: Identifier passed to both ``AutoModelForSeq2SeqLM.from_pretrained``
                and ``AutoTokenizer.from_pretrained``.
            keyphrase_sep_token: Separator the generated text is split on in
                :meth:`postprocess`.
            *args: Extra positional arguments forwarded to the parent pipeline
                after the model and tokenizer.
            **kwargs: Extra keyword arguments forwarded to the parent pipeline.
        """
        # Model and tokenizer are passed positionally so any *args follow them.
        # Passing them as keywords next to *args made every positional extra
        # bind to ``model`` first and raise TypeError.
        super().__init__(
            AutoModelForSeq2SeqLM.from_pretrained(model),
            # NOTE(review): confirm that truncation=True given at load time is
            # honoured when the pipeline tokenizes its inputs.
            AutoTokenizer.from_pretrained(model, truncation=True),
            *args,
            **kwargs,
        )
        self.keyphrase_sep_token = keyphrase_sep_token

    def postprocess(self, model_outputs):
        """Convert raw model outputs into a list of cleaned keyphrase strings.

        Args:
            model_outputs: Forwarded unchanged to the parent ``postprocess``.

        Returns:
            The keyphrases parsed from the ``"generated_text"`` of the first
            record returned by the parent ``postprocess``. Any further records
            are ignored. Each keyphrase has punctuation removed and is
            whitespace-stripped, and no empty strings are returned.
        """
        results = super().postprocess(model_outputs=model_outputs)
        # Only the first generated record is used; the others are ignored.
        generated_text = results[0].get("generated_text")
        # Remove punctuation before stripping so a space left behind by removed
        # punctuation (e.g. "term .") is also stripped. Filter on that same
        # cleaned value so whitespace-only pieces never yield "".
        cleaned = (
            piece.translate(self._PUNCTUATION_TABLE).strip()
            for piece in generated_text.split(self.keyphrase_sep_token)
        )
        return [keyphrase for keyphrase in cleaned if keyphrase]