File size: 1,565 Bytes
488f448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

from fire import Fire
import string
import tensorflow as tf
from transformers import  AutoTokenizer
from hazm import *
from transformers import pipeline
from transformers import TextClassificationPipeline
# Model id passed to AutoTokenizer.from_pretrained in `predict`; presumably a
# Persian BERT checkpoint on the Hugging Face Hub (matches the hazm usage).
original_model: str = "HooshvareLab/bert-fa-base-uncased"
# Location handed to tf.keras.models.load_model in `predict` (presumably a
# SavedModel directory — confirm against how the model was exported).
model_path: str = "models"
def remove_punctuation(input_string, *, punctuation=string.punctuation):
    """Return *input_string* with every character in *punctuation* deleted.

    Args:
        input_string: Text to clean.
        punctuation: Characters to strip. Defaults to ``string.punctuation``,
            which is ASCII-only; Persian/Arabic marks such as '،', '؛', '؟',
            '«' and '»' are kept unless they are passed here explicitly.

    Returns:
        A new string with those characters removed; all other characters
        (letters, digits, whitespace) are left unchanged.
    """
    # A translation table mapping each listed character to None deletes them
    # all in a single pass over the string.
    translator = str.maketrans("", "", punctuation)
    return input_string.translate(translator)
def predict(file_path, verbose=False):
    """Classify the text in *file_path* with the saved Keras model and print it.

    Processing steps:

    1. The file is read as UTF-8.
    2. ASCII punctuation is stripped with ``remove_punctuation``.
    3. The text is normalized with hazm's ``Normalizer``.
    4. It is tokenized with the ``original_model`` tokenizer, truncating to
       128 tokens.
    5. The result is fed to the Keras model loaded from ``model_path``.

    Args:
        file_path: Path to a UTF-8 text file; its whole content is treated as
            a single input example.
        verbose: If true, also print the encoded ``input_ids`` /
            ``attention_mask`` tensors before predicting (debug output).

    Returns:
        None. ``predictions[0]`` from the model is printed; presumably the
        output for the single input example — confirm against the model's
        output structure.
    """
    # Read and close the file before the expensive tokenizer/model work; the
    # explicit encoding avoids locale-dependent decoding of Persian text.
    with open(file_path, "r", encoding="utf-8") as fh:
        text = fh.read()

    normalizer = Normalizer()
    text = normalizer.normalize(remove_punctuation(text))

    tokenizer = AutoTokenizer.from_pretrained(original_model)
    input_tokens = tokenizer.batch_encode_plus(
        [text],
        padding=True,
        truncation=True,
        return_tensors="tf",
        max_length=128,
    )
    model_inputs = {
        "input_ids": input_tokens["input_ids"],
        "attention_mask": input_tokens["attention_mask"],
    }
    if verbose:
        print(model_inputs)

    new_model = tf.keras.models.load_model(model_path)
    predictions = new_model.predict([model_inputs])
    print(predictions[0])
if __name__ == "__main__":
    # Run as a script: Fire exposes `predict` as a command-line interface, so
    # its arguments (e.g. the input file path) are taken from the command line.
    Fire(component=predict)