negarb committed on
Commit
488f448
·
1 Parent(s): e04caa2

Upload 2 files

Browse files
Files changed (2) hide show
  1. run.py +44 -0
  2. trainer.py +15 -0
run.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from fire import Fire
3
+ import string
4
+ import tensorflow as tf
5
+ from transformers import AutoTokenizer
6
+ from hazm import *
7
+ from transformers import pipeline
8
+ from transformers import TextClassificationPipeline
9
# Hugging Face model id; predict() passes it to AutoTokenizer.from_pretrained.
+ original_model = "HooshvareLab/bert-fa-base-uncased"
10
# Path that predict() passes to tf.keras.models.load_model.
+ model_path = 'models'
11
def remove_punctuation(input_string, punctuation=string.punctuation):
    """Return *input_string* with every character in *punctuation* removed.

    Args:
        input_string: Text to clean.
        punctuation: Characters to delete. Defaults to ``string.punctuation``,
            which holds ASCII punctuation only, so non-ASCII marks (for
            example the Persian comma) are kept unless the caller passes
            them explicitly.

    Returns:
        A new string without the listed characters.
    """
    # A deletion-only translation table strips all listed characters in one pass.
    return input_string.translate(str.maketrans("", "", punctuation))
16
def predict(file_path):
    """Run the saved Keras model on the text in *file_path* and print the result.

    The file content is stripped of punctuation with ``remove_punctuation``,
    normalized with hazm's ``Normalizer``, and tokenized with the tokenizer
    of ``original_model`` (truncated to 128 tokens). The model is loaded from
    ``model_path``. The tokenized inputs and the first element of the model's
    predictions are printed; nothing is returned.

    Args:
        file_path: Path to a text file containing the text to classify.
    """
    normalizer = Normalizer()
    tokenizer = AutoTokenizer.from_pretrained(original_model)

    # Decode explicitly as UTF-8: the platform default encoding is
    # locale-dependent and may fail on (or silently mangle) Persian text.
    with open(file_path, 'r', encoding="utf-8") as file:
        text = file.read()

    text = remove_punctuation(text)
    text = normalizer.normalize(text)

    input_tokens = tokenizer.batch_encode_plus(
        [text],
        padding=True,
        truncation=True,
        return_tensors="tf",
        max_length=128,
    )
    # Build the input mapping once so the printed inputs and the inputs
    # passed to the model are the same object.
    model_inputs = {
        "input_ids": input_tokens["input_ids"],
        "attention_mask": input_tokens["attention_mask"],
    }
    new_model = tf.keras.models.load_model(model_path)

    print(model_inputs)
    predictions = new_model.predict([model_inputs])
    print(predictions[0])
43
# Script entry point: hand predict to python-fire so it can be invoked from the command line.
+ if __name__ == '__main__':
44
+ Fire(predict)
trainer.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fire import Fire
2
+ from src.classifier.classifier import get_model
3
+ from src.dataset.dataset import prepare_dataset
4
+ from src.utils import load_config_file
5
+
6
def trainer(config_path):
    """Train a model from the configuration file at *config_path*.

    Loads the configuration, prepares the dataset from it, builds the model
    via ``get_model``, then calls the model's ``train`` and
    ``save_model_results`` methods in that order.

    Args:
        config_path: Path passed to ``load_config_file``.
    """
    settings = load_config_file(config_path)
    data = prepare_dataset(settings)

    classifier = get_model(settings, data)
    classifier.train()
    classifier.save_model_results()
12
+
13
# Script entry point: hand trainer to python-fire so it can be invoked from the command line.
+ if __name__ == '__main__':
14
+ Fire(trainer)
15
+