# Provenance (from the Hugging Face file view): by mojtaba-nafez,
# duplicated from mojtaba-nafez/persian-poem-recommender-based-on-text,
# commit 1bc9b9d, 2.86 kB.
from utils import get_datasets, build_loaders
from models import PoemTextModel
from train import train, test
from metrics import calc_metrics
from inference import predict_poems_from_text
from utils import get_poem_embeddings
import config as CFG
import json
def main():
    """
    Optionally train a CLIP model, then run an interactive poem-prediction loop.

    The user is first asked whether to train a new model. Answering "Y" (case and
    surrounding whitespace are ignored) runs the training branch, which writes the
    loss history to ``loss_history_<poem_encoder_model>_<text_encoder_model>.json``.
    The inference phase then always runs: it repeatedly reads an image filename,
    prints the 10 predicted beyts and appends them as JSON to
    ``<image_filename>_output__<poem_encoder_model>_<text_encoder_model>.json``.
    An empty filename or end of input (EOF) ends the loop.

    NOTE(review): CLIPModel, get_clip_datasets, build_image_loaders and
    predict_poems_from_image are used below but are not among the names this
    file imports; they must be brought into scope before this function can run.
    """
    train_or_not = input("Train a new CLIP model using text embeddings? (needs the sajjadayobi360/cc3mfav2 and adityajn105/flickr8k datasets to be downloaded)\n[Y/N]")
    if train_or_not.strip().upper() == 'Y':
        _train_clip_model()
    _run_inference()


def _train_clip_model():
    """Train a CLIPModel on the caption datasets and write its loss history to a JSON file."""
    # Please download sajjadayobi360/cc3mfav2 and adityajn105/flickr8k datasets from kaggle
    # !kaggle datasets download -d sajjadayobi360/cc3mfav2
    # !kaggle datasets download -d adityajn105/flickr8k
    # TODO: load the downloaded datasets here. clip_dataset_dict is still an empty
    # placeholder, so get_clip_datasets currently receives an empty list.
    clip_dataset_dict = []
    train_dataset, val_dataset, _test_dataset = get_clip_datasets(clip_dataset_dict)
    train_loader = build_image_loaders(train_dataset, mode="train")
    valid_loader = build_image_loaders(val_dataset, mode="valid")

    model = CLIPModel(image_encoder_pretrained=True,
                      text_encoder_pretrained=True,
                      text_projection_trainable=False,
                      is_image_poem_pair=False,
                      ).to(CFG.device)
    # NOTE(review): the inference phase builds a fresh CLIPModel instead of reusing
    # the model returned here; confirm the trained weights are persisted and
    # reloaded elsewhere (e.g. inside train()).
    model, loss_history = train(model, train_loader, valid_loader)

    loss_path = 'loss_history_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model)
    with open(loss_path, 'w', encoding="utf-8") as f:
        f.write(json.dumps(loss_history, indent=4))


def _run_inference():
    """Predict beyts for image filenames read from stdin until an empty line or EOF."""
    print("_"*20)
    print("INFERENCE PHASE")
    model = CLIPModel(image_encoder_pretrained=True,
                      text_encoder_pretrained=True,
                      text_projection_trainable=False,
                      is_image_poem_pair=True,
                      ).to(CFG.device)
    model.eval()

    with open(CFG.dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Embed the same poems whose beyts are handed to the predictor. The previous code
    # embedded test_dataset, which is unbound when training is skipped. Presumably
    # predict_poems_from_image maps embedding rows back to entries of the beyt list,
    # so both must come from the same dataset. TODO confirm against inference.py.
    model, poem_embeddings = get_poem_embeddings(dataset, model)
    candidate_beyts = [data['beyt'] for data in dataset]  # loop-invariant, build once

    while True:
        try:
            image_filename = input("Enter an image filename to predict poems for")
        except EOFError:
            break  # end of input: leave the loop instead of crashing
        if not image_filename:
            break  # empty line ends the interactive session
        beyts = predict_poems_from_image(model, poem_embeddings, image_filename, candidate_beyts, n=10)
        print("predicted Beyts: \n\t", "\n\t".join(beyts))
        output_path = '{}_output__{}_{}.json'.format(image_filename, CFG.poem_encoder_model, CFG.text_encoder_model)
        # Append mode: querying the same image again appends another JSON array
        # to the same file.
        with open(output_path, 'a+', encoding="utf-8") as f:
            f.write(json.dumps(beyts, ensure_ascii=False, indent=4))
if __name__ == '__main__':  # run the interactive flow only when executed as a script, not on import
    main()