Update pipeline.py
Browse files
- pipeline.py +3 -3
pipeline.py
CHANGED
@@ -6,7 +6,7 @@ import html.parser
|
|
6 |
import unicodedata
|
7 |
import sys, os, re
|
8 |
|
9 |
-
class
|
10 |
|
11 |
def __init__(self, beam_size=5, batch_size=32, **kwargs):
|
12 |
self.beam_size = beam_size
|
@@ -153,7 +153,7 @@ class ReaccentPipeline(Pipeline):
|
|
153 |
def normalise_text(list_sents, batch_size=32, beam_size=5):
|
154 |
tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
155 |
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
156 |
-
normalisation_pipeline =
|
157 |
tokenizer=tokeniser,
|
158 |
batch_size=batch_size,
|
159 |
beam_size=beam_size)
|
@@ -163,7 +163,7 @@ def normalise_text(list_sents, batch_size=32, beam_size=5):
|
|
163 |
def normalise_from_stdin(batch_size=32, beam_size=5):
|
164 |
tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
165 |
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
166 |
-
normalisation_pipeline =
|
167 |
tokenizer=tokeniser,
|
168 |
batch_size=batch_size,
|
169 |
beam_size=beam_size)
|
|
|
6 |
import unicodedata
|
7 |
import sys, os, re
|
8 |
|
9 |
+
class NormalisationPipeline(Pipeline):
|
10 |
|
11 |
def __init__(self, beam_size=5, batch_size=32, **kwargs):
|
12 |
self.beam_size = beam_size
|
|
|
153 |
def normalise_text(list_sents, batch_size=32, beam_size=5):
|
154 |
tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
155 |
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
156 |
+
normalisation_pipeline = NormalisationPipeline(model=model,
|
157 |
tokenizer=tokeniser,
|
158 |
batch_size=batch_size,
|
159 |
beam_size=beam_size)
|
|
|
163 |
def normalise_from_stdin(batch_size=32, beam_size=5):
|
164 |
tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
165 |
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
166 |
+
normalisation_pipeline = NormalisationPipeline(model=model,
|
167 |
tokenizer=tokeniser,
|
168 |
batch_size=batch_size,
|
169 |
beam_size=beam_size)
|