File size: 1,777 Bytes
9d21d47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np


class BaseProbInference:
    """Base class for probability-based multiple-choice inference tasks.

    Subclasses supply the task-specific hooks (every hook here raises
    ``NotImplementedError``); this base only resolves the prompt version,
    initialises shared attributes, and offers two static helpers for
    combining per-choice scores and turning them into predictions.
    """

    def __init__(self, prompt_version):
        """Store the prompt version and reset the shared task attributes.

        Args:
            prompt_version: The literal string ``"default"`` selects whatever
                ``default_prompt_version()`` returns; any other value is
                stored unchanged.
        """
        # Resolve the prompt version first, before any other attribute exists,
        # so an overriding default_prompt_version() sees the same instance
        # state it always has.
        self.prompt_version = (
            self.default_prompt_version()
            if prompt_version == "default"
            else prompt_version
        )

        # Placeholders; presumably populated later with the loaded dataset
        # parts -- nothing in this class fills them in.
        self.raw_data_result = None
        self.raw_data_sample = None
        self.raw_data_dev = None

        # Defaults that subclasses are presumably expected to override.
        self.can_be_stratified = False
        self.CHOICES = None
        self.num_base_shot = 1

    def default_prompt_version(self):
        """Return the prompt version used when ``"default"`` is requested.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def dataset_signature(self):
        """Return the dataset parts used by this task.

        Must be overridden by subclasses. Expected shape::

            {
                "result": (dataset_name, subset, split),  # produces the final result
                "sample": (dataset_name, subset, split),  # source of ICL few-shot examples
            }
        """
        raise NotImplementedError

    def dataset_part(self, part):
        """Return the ``dataset_signature()`` entry stored under ``part``."""
        signature = self.dataset_signature()
        return signature[part]

    def dataset_preprocess(self, raw_data):
        """Preprocess ``raw_data``; must be overridden by subclasses."""
        raise NotImplementedError

    def handcrafted_exemplars(self):
        """Provide handcrafted exemplars; must be overridden by subclasses."""
        raise NotImplementedError

    def exemplar_seperator(self):
        """Must be overridden by subclasses.

        Presumably returns the separator placed between exemplars -- confirm
        against the concrete subclasses.
        """
        raise NotImplementedError

    def multiple_choice_promptify(self, query, choice):
        """Must be overridden by subclasses.

        Presumably builds the prompt for ``query`` paired with ``choice`` --
        confirm against the concrete subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def merge_choice_info(choice_info):
        """Collect the per-choice scores into one list per metric.

        Args:
            choice_info: Iterable of mappings, each providing the keys
                ``"lm_log_p"`` and ``"norm_lm_log_p"``.

        Returns:
            A dict with those two keys, each mapped to the list of that
            metric's values in the order the entries were given.
        """
        return {
            metric: [entry[metric] for entry in choice_info]
            for metric in ("lm_log_p", "norm_lm_log_p")
        }

    @staticmethod
    def choice_info_to_predictions(info):
        """Pick the highest-scoring choice index for each metric.

        Args:
            info: Mapping with ``"lm_log_p"`` and ``"norm_lm_log_p"``
                entries, each a sequence of scores.

        Returns:
            A dict with the same two keys, each mapped to the ``int`` index
            of the largest score for that metric.
        """
        return {
            metric: int(np.argmax(info[metric]))
            for metric in ("lm_log_p", "norm_lm_log_p")
        }