Upload 2 files
Browse files- run.py +44 -0
- trainer.py +15 -0
run.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from fire import Fire
|
3 |
+
import string
|
4 |
+
import tensorflow as tf
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
from hazm import *
|
7 |
+
from transformers import pipeline
|
8 |
+
from transformers import TextClassificationPipeline
|
9 |
+
original_model = "HooshvareLab/bert-fa-base-uncased"
|
10 |
+
model_path = 'models'
|
11 |
+
def remove_punctuation(input_string):
|
12 |
+
translator = str.maketrans("", "", string.punctuation)
|
13 |
+
|
14 |
+
result = input_string.translate(translator)
|
15 |
+
return result
|
16 |
+
def predict(file_path):
|
17 |
+
normalizer = Normalizer()
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained(original_model)
|
19 |
+
# classifier = pipeline("text-classification", model="stevhliu/my_awesome_model")
|
20 |
+
|
21 |
+
with open(file_path, 'r') as file:
|
22 |
+
text = file.read()
|
23 |
+
|
24 |
+
text = remove_punctuation(text)
|
25 |
+
text = normalizer.normalize(text)
|
26 |
+
|
27 |
+
input_tokens = tokenizer.batch_encode_plus(
|
28 |
+
[text],
|
29 |
+
padding=True,
|
30 |
+
truncation=True,
|
31 |
+
return_tensors="tf",
|
32 |
+
max_length=128
|
33 |
+
)
|
34 |
+
input_ids = input_tokens["input_ids"]
|
35 |
+
attention_mask = input_tokens["attention_mask"]
|
36 |
+
new_model = tf.keras.models.load_model(model_path)
|
37 |
+
# pipe = TextClassificationPipeline(model=new_model, tokenizer=tokenizer, return_all_scores=True)
|
38 |
+
|
39 |
+
print({"input_ids": input_ids, "attention_mask": attention_mask})
|
40 |
+
predictions = new_model.predict([{"input_ids": input_ids, "attention_mask": attention_mask}])
|
41 |
+
print(predictions[0])
|
42 |
+
# print(pipe([text]))
|
43 |
+
if __name__ == '__main__':
|
44 |
+
Fire(predict)
|
trainer.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fire import Fire
|
2 |
+
from src.classifier.classifier import get_model
|
3 |
+
from src.dataset.dataset import prepare_dataset
|
4 |
+
from src.utils import load_config_file
|
5 |
+
|
6 |
+
def trainer(config_path):
|
7 |
+
config = load_config_file(config_path)
|
8 |
+
dataset = prepare_dataset(config)
|
9 |
+
model = get_model(config,dataset)
|
10 |
+
model.train()
|
11 |
+
model.save_model_results()
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
Fire(trainer)
|
15 |
+
|