File size: 2,959 Bytes
246d201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import json
import logging
import os
from abc import ABC, abstractmethod

from utils import load_file

# Module-level logger named 'MINT'; Task.load_tasks reports the number of
# tasks it loaded (and the source path) through it.
LOGGER = logging.getLogger('MINT')


class Task(ABC):
    """Base class for a task instance."""

    task_name: str = 'base'
    in_context_example_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'in_context_examples',
    )

    def __init__(self, **kwargs) -> None:
        if 'loaded_history' in kwargs:
            self.loaded_history = kwargs['loaded_history']
        else:
            self.loaded_history = None
        # pre-load the in-context example
        task_dir = os.path.join(self.in_context_example_dir, self.task_name)
        self._in_context_example = {
            'with_tool': load_file(os.path.join(task_dir, 'with_tool.txt')),
        }
        self.metadata = {}

    @property
    def task_id(self) -> str:
        """Return the task id."""
        assert hasattr(self, '_id'), 'Task does not have an id.'
        return self._id

    def in_context_example(

        self, use_tool: bool = True, with_feedback: bool = False

    ) -> str:
        """Return the in-context example for the task."""
        if use_tool and not with_feedback:
            return self._in_context_example['with_tool']
        else:
            raise NotImplementedError

    @property
    def prompt(self) -> str:
        """Return the task prompt."""
        assert hasattr(self, '_prompt'), 'Task does not have a prompt.'
        return self._prompt

    @property
    def reference(self) -> str:
        """Return the reference solution for the task."""
        assert hasattr(self, '_reference'), 'Task does not have a reference solution.'
        return self._reference

    @abstractmethod
    def extract_answer(self, solution: str) -> str | None:
        """Extract the answer from the given solution."""
        pass

    @abstractmethod
    def success(self, solution: str) -> bool:
        """This checks whether the given solution can complete the current task.



        Can be used to provide binary feedback.

        """
        answer = self.extract_answer(solution)
        return answer == self.reference

    @classmethod
    def load_tasks(cls, path: str) -> tuple[list['Task'], int]:
        """Load all the tasks from a given jsonl file."""
        assert path.endswith('.jsonl') or path.endswith('.json')
        with open(path, 'r') as f:
            tasks = [cls(**json.loads(line)) for line in f.readlines()]
        LOGGER.info(f'Loaded {len(tasks)} tasks from {path}')
        return tasks, len(tasks)

    def to_dict(self) -> dict:
        """Convert the task to a dictionary."""
        return {
            'task_name': self.task_name,
            'task_id': self.task_id,
            'prompt': self.prompt,
            'reference': self.reference,
            'metadata': self.metadata,
        }