# NOTE(review): the three lines below were paste/extraction artifacts
# (file size, commit hash, line-number gutter) — commented out so the
# module parses:
# File size: 1,565 Bytes
# 488f448
from fire import Fire
import string
import tensorflow as tf
from transformers import AutoTokenizer
from hazm import *
from transformers import pipeline
from transformers import TextClassificationPipeline
# Hugging Face model ID; used here only to obtain the matching tokenizer —
# the fine-tuned classifier weights are loaded separately from `model_path`.
original_model = "HooshvareLab/bert-fa-base-uncased"
# Directory containing the saved fine-tuned Keras model.
model_path = 'models'
def remove_punctuation(input_string):
    """Return *input_string* with every ASCII punctuation character removed.

    Uses a single C-level ``str.translate`` pass; characters listed in
    ``string.punctuation`` are mapped to ``None`` (i.e. deleted).
    """
    strip_table = str.maketrans({ch: None for ch in string.punctuation})
    return input_string.translate(strip_table)
def predict(file_path):
    """Classify the Persian text stored at *file_path* and print the result.

    Pipeline: read the file, strip ASCII punctuation, normalize with hazm,
    tokenize with the base model's tokenizer, then run the fine-tuned Keras
    model loaded from ``model_path``.

    Args:
        file_path: Path to a UTF-8 text file containing the input document.

    Side effects:
        Prints the encoded inputs and the first prediction row to stdout.
    """
    normalizer = Normalizer()
    tokenizer = AutoTokenizer.from_pretrained(original_model)
    # Persian text: read with an explicit encoding so behavior does not
    # depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()
    text = remove_punctuation(text)
    text = normalizer.normalize(text)
    # Truncate/pad to 128 tokens and return TensorFlow tensors.
    input_tokens = tokenizer.batch_encode_plus(
        [text],
        padding=True,
        truncation=True,
        return_tensors="tf",
        max_length=128,
    )
    input_ids = input_tokens["input_ids"]
    attention_mask = input_tokens["attention_mask"]
    # NOTE(review): the model is reloaded on every call; hoist to module
    # level if predict() is invoked repeatedly.
    new_model = tf.keras.models.load_model(model_path)
    print({"input_ids": input_ids, "attention_mask": attention_mask})
    # NOTE(review): the dict is wrapped in a single-element list — confirm
    # this matches the saved model's expected input signature.
    predictions = new_model.predict([{"input_ids": input_ids, "attention_mask": attention_mask}])
    print(predictions[0])
if __name__ == '__main__':
    # Expose `predict` as a CLI: `python <script> <file_path>`.
    # (Removed a stray trailing "|" left over from a copy/paste artifact.)
    Fire(predict)