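# Requires NLTK's tokenizer and POS-tagger data plus the universal tagset
# mapping (tagset="universal"); fetch them once with nltk.download(...).
# Exact package names vary by NLTK version (e.g. "punkt" vs "punkt_tab").
import json

import torch
from nltk import pos_tag
from nltk.tokenize import word_tokenize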

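def batch_texts_POS_analysis(batch_texts, pos_templete, device="cuda"):
    """Tag each text with universal POS tags and score it against a template.

    pos_templete[i] is the tag (or a list/set of tags) allowed at position i;
    an empty string "" is a wildcard that matches anything. A text's score is
    the fraction of template positions whose tag matches.

    Returns (list of per-text tag lists, 1-D tensor of scores on `device`).
    """
    # Fall back to CPU so the function still runs on machines without CUDA.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    batch_size = len(batch_texts)
    total_num = len(pos_templete)
    pos_tags = []
    pos_scores = torch.zeros(batch_size, device=device)

    for b_id, text in enumerate(batch_texts):
        words = word_tokenize(text)
        res_tag = [tag for _, tag in pos_tag(words, tagset="universal")]
        # Pad short texts with None (and truncate long ones) so the tag
        # sequence lines up one-to-one with the template.
        cur_tag = (res_tag + [None] * total_num)[:total_num]
        correct = 0
        for tag, allowed in zip(cur_tag, pos_templete):
            if allowed == "":
                correct += 1  # wildcard position
            elif tag is None:
                continue  # text shorter than the template: no credit
            elif isinstance(allowed, str):
                correct += int(tag == allowed)
            elif tag in allowed:
                correct += 1
        pos_tags.append(res_tag)
        pos_scores[b_id] = correct / total_num if total_num else 0.0

    return pos_tags, pos_scores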

def text_POS_analysis(text):
    """Return the universal POS tags of `text`, one per token."""
    words = word_tokenize(text)
    word_tag = pos_tag(words, tagset="universal")
    return [tag for _, tag in word_tag]

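if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_texts = ["A cat sitting in the bed.",
                   "Two men in a nice hotel room one playing a video game with a remote control.",
                   "The man sitting in the chair feels like an invisible,dead man."]
    pos_templete = ['DET', 'NOUN', 'ADP', 'ADJ', 'NOUN', '.', 'NOUN', 'CONJ', 'NOUN', 'ADP', 'PRON', '.']

    # Quick sanity check on the hard-coded examples.
    example_tags, example_scores = batch_texts_POS_analysis(batch_texts, pos_templete, device=device)
    print(example_scores)

    # Load captions from a JSON file whose values are either a caption
    # string or a list of caption strings.
    cur_path = "iter_15.json"
    all_caption = []
    with open(cur_path, "r") as cur_json_file:
        all_res = list(json.load(cur_json_file).values())
    for res in all_res:
        if isinstance(res, list):
            all_caption += res
        else:
            all_caption.append(res)

    pos_tags, pos_scores = batch_texts_POS_analysis(all_caption, pos_templete, device=device)

    # Count which universal POS tag appears at 0-based position `word_id`
    # (12 = the 13th token, the first one past the 12-entry template);
    # captions with fewer tokens are skipped.
    word_id = 12
    pos_dict = {"ADJ": 0, "ADP": 0, "ADV": 0,
                "CONJ": 0, "DET": 0, "NOUN": 0, "X": 0,
                "NUM": 0, "PRT": 0, "PRON": 0, "VERB": 0, ".": 0}
    # Do not name the loop variable `pos_tag`: that would shadow nltk.pos_tag.
    for tags in pos_tags:
        if word_id < len(tags):
            pos_dict[tags[word_id]] += 1
    print(pos_dict)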