# Uploaded by mojtaba-nafez.
# Duplicated from mojtaba-nafez/persian-poem-recommender-based-on-text (commit 1bc9b9d).
from __future__ import annotations
import torch
import cv2
import torch.nn.functional as F
import numpy as np
import config as CFG
from datasets import get_transforms
# Needed only when running this script as __main__:
from utils import get_datasets, build_loaders
from models import PoemTextModel
from utils import get_poem_embeddings
import json
import os
import regex
def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10, return_similarities=False):
    """
    Returns up to n distinct poems which are the most similar to a text query

    Parameters:
    -----------
    model: PoemTextModel
        model to compute the text query's embedding with (its text_encoder and text_projection are used)
    poem_embeddings: torch.Tensor (or array-like) with shape (#poems, CFG.projection_dim)
        poem embeddings to check similarity against; converted to the query embedding's device and dtype
    query: str
        text query
    poems: list of str
        poems corresponding row-by-row to poem_embeddings
    text_tokenizer: huggingface Tokenizer, optional
        tokenizer to tokenize query with. if None, a new text tokenizer is instantiated using configs
        (pass one in when calling this repeatedly, since loading a tokenizer is slow)
    n: int, optional
        maximum number of poems to return; fewer are returned if there are not enough distinct poems
    return_similarities: bool, optional
        if True, a list of dicts {'beyt': str, 'similarity': 0-d tensor} is returned instead of plain strings

    Returns:
    --------
    A list of at most n distinct poem strings (or dicts, see return_similarities), most similar first.

    Raises:
    -------
    ValueError
        if the number of poem embeddings does not match len(poems)
    """
    # Fail fast (before loading a tokenizer or running the model) on mis-paired inputs;
    # otherwise ranked indices could point past the end of `poems`.
    if len(poem_embeddings) != len(poems):
        raise ValueError(
            f"poem_embeddings has {len(poem_embeddings)} rows but {len(poems)} poems were given"
        )
    if len(poems) == 0:
        return []

    # Tokenizing and encoding the query text
    if text_tokenizer is None:
        text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
    # truncation=True keeps over-long queries within the tokenizer's maximum model input
    # length instead of failing inside the encoder
    encoded_query = text_tokenizer([query], truncation=True)
    batch = {key: torch.tensor(values).to(CFG.device) for key, values in encoded_query.items()}

    # Getting the query text's embedding
    model.eval()
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    # Cosine similarity: L2-normalize both sides, then take dot products.
    # as_tensor accepts array-likes and avoids a device/dtype mismatch in the matmul.
    poem_embeddings = torch.as_tensor(
        poem_embeddings, device=text_embeddings.device, dtype=text_embeddings.dtype
    )
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    similarities = (text_embeddings_n @ poem_embeddings_n.T).squeeze(0)

    # Rank all poems: duplicates are skipped below, so more than n candidates may be inspected.
    values, indices = torch.topk(similarities, len(poems))

    # Since we collected poems from many sources, some of them are equal (the same beyt with
    # different meanings). Two poems count as duplicates when they contain the same sequence of
    # letter runs once zero-width non-joiners (U+200C) are removed. Keys are always computed from
    # the raw poem text, so the display clean-up below cannot hide a duplicate.
    def dedup_key(poem):
        return tuple(regex.findall(r'\p{L}+', poem.replace('\u200c', '')))

    results = []
    seen_keys = set()
    for value, index in zip(values, indices.tolist()):
        if len(results) == n:
            break
        key = dedup_key(poems[index])
        if key in seen_keys:
            continue
        seen_keys.add(key)
        results.append({
            # collapse the '*' hemistich separators used in the dataset
            'beyt': poems[index].replace(' * * ', ' * ').replace('*** * ', ''),
            'similarity': value,
        })

    if return_similarities:
        return results
    return [res['beyt'] for res in results]
def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10, return_similarities=False):
    """
    Returns up to n distinct poems which are the most similar to an image query

    Parameters:
    -----------
    model: CLIPModel
        model to compute the image query's embedding with (its image_encoder and image_projection are used)
    poem_embeddings: torch.Tensor (or array-like) with shape (#poems, CFG.projection_dim)
        poem embeddings to check similarity against; converted to the image embedding's device and dtype
    image_filename: str or path-like
        path and file name for the image query
    poems: list of str
        poems corresponding row-by-row to poem_embeddings
    n: int, optional
        maximum number of poems to return; fewer are returned if there are not enough distinct poems
    return_similarities: bool, optional
        if True, a list of dicts {'beyt': str, 'similarity': 0-d tensor} is returned instead of plain strings

    Returns:
    --------
    A list of at most n distinct poem strings (or dicts, see return_similarities), most similar first.

    Raises:
    -------
    ValueError
        if the number of poem embeddings does not match len(poems), or the image cannot be read
    """
    # Fail fast on mis-paired inputs; otherwise ranked indices could point past the end of `poems`.
    if len(poem_embeddings) != len(poems):
        raise ValueError(
            f"poem_embeddings has {len(poem_embeddings)} rows but {len(poems)} poems were given"
        )
    if len(poems) == 0:
        return []

    # Reading, processing and applying transforms to the image (all explained in datasets.py)
    image = cv2.imread(f"{image_filename}")
    if image is None:
        # cv2.imread reports a missing/unreadable/unsupported file by returning None, not by raising
        raise ValueError(f"Could not read image file: {image_filename}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms(mode="test")(image=image)['image']
    image = torch.tensor(image).permute(2, 0, 1).float()  # HWC -> CHW

    # Getting the image query's embedding (unsqueeze adds a batch dimension of 1)
    model.eval()
    with torch.no_grad():
        image_features = model.image_encoder(torch.unsqueeze(image, 0).to(CFG.device))
        image_embeddings = model.image_projection(image_features)

    # Cosine similarity between the image embedding and every poem embedding.
    # as_tensor accepts array-likes and avoids a device/dtype mismatch in the matmul.
    poem_embeddings = torch.as_tensor(
        poem_embeddings, device=image_embeddings.device, dtype=image_embeddings.dtype
    )
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    similarities = (image_embeddings_n @ poem_embeddings_n.T).squeeze(0)

    # Rank all poems: duplicates are skipped below, so more than n candidates may be inspected.
    values, indices = torch.topk(similarities, len(poems))

    # Since we collected poems from many sources, some of them are equal (the same beyt with
    # different meanings). Two poems count as duplicates when they contain the same sequence of
    # letter runs once zero-width non-joiners (U+200C) are removed. Keys are always computed from
    # the raw poem text, so the display clean-up below cannot hide a duplicate.
    def dedup_key(poem):
        return tuple(regex.findall(r'\p{L}+', poem.replace('\u200c', '')))

    results = []
    seen_keys = set()
    for value, index in zip(values, indices.tolist()):
        if len(results) == n:
            break
        key = dedup_key(poems[index])
        if key in seen_keys:
            continue
        seen_keys.add(key)
        results.append({
            # collapse the '*' hemistich separators used in the dataset
            'beyt': poems[index].replace(' * * ', ' * ').replace('*** * ', ''),
            'similarity': value,
        })

    if return_similarities:
        return results
    return [res['beyt'] for res in results]
def _safe_filename_part(text, max_length=80):
    """Return `text` made safe to embed in a file name.

    Path separators, characters reserved on Windows and control characters are replaced
    with '_', and the result is cut to `max_length` characters so the final file name stays
    under common 255-byte limits (Persian letters take 2 bytes each in UTF-8).
    Falls back to "query" if nothing is left.
    """
    cleaned = regex.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', text).strip()
    return cleaned[:max_length] or "query"


if __name__ == "__main__":
    # Creates a PoemTextModel based on configs, and outputs some examples of its predictions.

    # Get datasets from dataset_path (built the same way as the train, val and test files in the data directory)
    train_dataset, val_dataset, test_dataset = get_datasets()
    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model.eval()

    # Load the tokenizer once and reuse it; predict_poems_from_text would otherwise call
    # from_pretrained again for every single query.
    text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)

    # Inference: output some example predictions and write them to a file
    print("_" * 20)
    print("Output Examples from test set")
    model, poem_embeddings = get_poem_embeddings(test_dataset, model)
    test_poems = [data['beyt'] for data in test_dataset]  # built once, not once per example
    example = {}
    for i, test_data in enumerate(test_dataset[:100]):
        example[i] = {
            'Text': test_data["text"],
            'True Beyt': test_data["beyt"],
            "Predicted Beyt": predict_poems_from_text(
                model, poem_embeddings, test_data["text"], test_poems,
                text_tokenizer=text_tokenizer, n=10,
            ),
        }
    # Slice rather than range(10) so a test set with fewer than 10 items doesn't raise KeyError
    for item in list(example.values())[:10]:
        print("Text: ", item['Text'])
        print("True Beyt: ", item['True Beyt'])
        print("predicted Beyts: \n\t", "\n\t".join(item["Predicted Beyt"]))
    example_path = 'example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model)
    with open(example_path, 'w', encoding="utf-8") as f:
        json.dump(example, f, ensure_ascii=False, indent=4)

    print("Preparing model for user input...")
    with open(CFG.dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    model, poem_embeddings = get_poem_embeddings(dataset, model)
    dataset_poems = [data['beyt'] for data in dataset]
    while True:
        try:
            user_text = input("Enter a Text to find poem beyts for: ")
        except (EOFError, KeyboardInterrupt):
            # end of input / Ctrl+C: leave the loop cleanly instead of dumping a traceback
            print()
            break
        if not user_text.strip():
            continue
        beyts = predict_poems_from_text(
            model, poem_embeddings, user_text, dataset_poems, text_tokenizer=text_tokenizer, n=10
        )
        print("predicted Beyts: \n\t", "\n\t".join(beyts))
        # User text is sanitized before it becomes part of a file name. The file is opened in
        # 'w' mode: appending a second JSON array to an existing file would make it invalid JSON.
        output_path = '{}_output__{}_{}.json'.format(
            _safe_filename_part(user_text), CFG.poem_encoder_model, CFG.text_encoder_model
        )
        with open(output_path, 'w', encoding="utf-8") as f:
            json.dump(beyts, f, ensure_ascii=False, indent=4)