from __future__ import annotations
import torch
import cv2
import torch.nn.functional as F
import numpy as np
import config as CFG
from datasets import get_transforms
# imports for running this script as __main__
from utils import get_datasets, get_poem_embeddings
from models import PoemTextModel
import json
import os
import regex
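# configuration fields used in this script: CFG.device, CFG.tokenizers,
# CFG.text_encoder_model, CFG.text_tokenizer, CFG.poem_encoder_model,
# CFG.projection_dim and CFG.dataset_path (all defined in config.py)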
def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10, return_similarities=False):
"""
    Returns the n poems most similar to a text query.
    Parameters:
    -----------
    model: PoemTextModel
        model to compute the text query's embedding
    poem_embeddings: torch.Tensor with shape (#poems, CFG.projection_dim)
        poem embeddings to check similarity against
    query: str
        text query
    poems: list of str
        poems corresponding to poem_embeddings
    text_tokenizer: huggingface Tokenizer, optional
        tokenizer to tokenize the query with. If None, a new text tokenizer is instantiated using configs.
    n: int, optional
        number of poems to return
    return_similarities: bool, optional
        if True, returns a list of dicts, each holding a poem beyt and its similarity to the query
    Returns:
    --------
    A list of the n poem strings whose embeddings are most similar to the query text's embedding
    (or, if return_similarities is True, a list of {'beyt', 'similarity'} dicts).
"""
    # tokenizing and encoding the query text
if not text_tokenizer:
text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
encoded_query = text_tokenizer([query])
batch = {
key: torch.tensor(values).to(CFG.device)
for key, values in encoded_query.items()
}
# getting query text's embeddings
model.eval()
with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
text_embeddings = model.text_projection(text_features)
# normalizing and computing dot similarity of poem and text embeddings
poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
dot_similarity = text_embeddings_n @ poem_embeddings_n.T
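    # the dot product of L2-normalized embeddings equals their cosine similarity,
    # so each score lies in [-1, 1]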
    # ranking all poems by embedding similarity (all of them, not just n,
    # since duplicates may be skipped below)
    values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
    # since we collected poems from many sources, some entries are effectively equal
    # (the same beyt with different spellings or spacings), so we must make sure the
    # poems added to the results are not duplicates
    def is_poem_duplicate(poem, poems):
        # compare poems by their letter tokens only, ignoring punctuation, spacing
        # and the zero-width non-joiner (U+200C) common in Persian text
        poem_words = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
        for other_poem in poems:
            other_words = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
            if poem_words == other_words:
                return True
        return False
    # collect the top n non-duplicate poems, in order of decreasing similarity
    results = []
    computed_k = 0
    for i in range(len(poems)):
        if computed_k == n:
            break
        if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
            results.append({
                # clean up redundant '*' separator tokens
                'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
                # .item() converts the 0-dim tensor to a plain float (JSON-serializable)
                'similarity': values[i].item()
            })
            computed_k += 1
if return_similarities:
return results
else:
return [res['beyt'] for res in results]
def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10, return_similarities=False):
"""
    Returns the n poems most similar to an image query.
    Parameters:
    -----------
    model: CLIPModel
        model to compute the image query's embedding
    poem_embeddings: torch.Tensor with shape (#poems, CFG.projection_dim)
        poem embeddings to check similarity against
    image_filename: str
        path to the image query file
    poems: list of str
        poems corresponding to poem_embeddings
    n: int, optional
        number of poems to return
    return_similarities: bool, optional
        if True, returns a list of dicts, each holding a poem beyt and its similarity to the image
    Returns:
    --------
    A list of the n poem strings whose embeddings are most similar to the image query's embedding
    (or, if return_similarities is True, a list of {'beyt', 'similarity'} dicts).
"""
    # reading, converting to RGB, and applying the test-time transforms to the image
    # (all explained in datasets.py)
    image = cv2.imread(image_filename)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {image_filename}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms(mode="test")(image=image)['image']
    # HWC -> CHW, as expected by the image encoder
    image = torch.tensor(image).permute(2, 0, 1).float()
# getting image query's embeddings
model.eval()
with torch.no_grad():
image_features = model.image_encoder(torch.unsqueeze(image, 0).to(CFG.device))
image_embeddings = model.image_projection(image_features)
    # normalizing and computing dot similarity of poem and image embeddings
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    dot_similarity = image_embeddings_n @ poem_embeddings_n.T
    # ranking all poems by embedding similarity (duplicates may be skipped below)
    values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
    # same duplicate check as in predict_poems_from_text: compare poems by their
    # letter tokens only, ignoring spacing and the zero-width non-joiner (U+200C)
    def is_poem_duplicate(poem, poems):
        poem_words = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
        for other_poem in poems:
            other_words = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
            if poem_words == other_words:
                return True
        return False
    # collect the top n non-duplicate poems, in order of decreasing similarity
    results = []
    computed_k = 0
    for i in range(len(poems)):
        if computed_k == n:
            break
        if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
            results.append({
                'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
                'similarity': values[i].item()
            })
            computed_k += 1
if return_similarities:
return results
else:
return [res['beyt'] for res in results]
if __name__ == "__main__":
"""
Creates a PoemTextModel based on configs, and outputs some examples of its prediction.
"""
# get dataset from dataset_path (the same datasets as the train, val and test dataset files in the data directory is made)
train_dataset, val_dataset, test_dataset = get_datasets()
model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
model.eval()
# Inference: Output some example predictions and write them in a file
print("_"*20)
print("Output Examples from test set")
model, poem_embeddings = get_poem_embeddings(test_dataset, model)
example = {}
    for i, test_data in enumerate(test_dataset[:100]):
        example[i] = {
            'Text': test_data["text"],
            'True Beyt': test_data["beyt"],
            'Predicted Beyt': predict_poems_from_text(
                model, poem_embeddings, test_data["text"],
                [data['beyt'] for data in test_dataset], n=10
            ),
        }
for i in range(10):
print("Text: ", example[i]['Text'])
print("True Beyt: ", example[i]['True Beyt'])
print("predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"]))
    with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model), 'w', encoding="utf-8") as f:
        f.write(json.dumps(example, ensure_ascii=False, indent=4))
print("Preparing model for user input...")
with open(CFG.dataset_path, encoding="utf-8") as f:
dataset = json.load(f)
model, poem_embeddings = get_poem_embeddings(dataset, model)
    while True:
        user_text = input("Enter a Text to find poem beyts for: ")
        beyts = predict_poems_from_text(model, poem_embeddings, user_text, [data['beyt'] for data in dataset], n=10)
        print("predicted Beyts: \n\t", "\n\t".join(beyts))
        # note: the query text itself is used in the output file name
        with open('{}_output__{}_{}.json'.format(user_text, CFG.poem_encoder_model, CFG.text_encoder_model), 'a+', encoding="utf-8") as f:
            f.write(json.dumps(beyts, ensure_ascii=False, indent=4))