File size: 4,067 Bytes
594c559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import warnings
from typing import Any, Optional

from camel.prompts import TaskPromptTemplateDict, TextPrompt
from camel.typing import RoleType, TaskType


class PromptTemplateGenerator:
    r"""A class for generating prompt templates for tasks.

    Args:
        task_prompt_template_dict (TaskPromptTemplateDict, optional):
            A dictionary of task prompt templates for each task type. If
            omitted, or if a falsy value (e.g. an empty mapping) is passed,
            a default-constructed :obj:`TaskPromptTemplateDict` is used
            instead.
    """

    def __init__(
        self,
        task_prompt_template_dict: Optional[TaskPromptTemplateDict] = None,
    ) -> None:
        # `or` means any falsy argument (None, empty mapping) falls back to
        # a freshly constructed TaskPromptTemplateDict.
        self.task_prompt_template_dict = (task_prompt_template_dict
                                          or TaskPromptTemplateDict())

    def get_prompt_from_key(self, task_type: TaskType, key: Any) -> TextPrompt:
        r"""Generates a text prompt using the specified :obj:`task_type` and
        :obj:`key`.

        Args:
            task_type (TaskType): The type of task.
            key (Any): The key used to generate the prompt.

        Returns:
            TextPrompt: The generated text prompt.

        Raises:
            KeyError: If failed to generate prompt using the specified
                :obj:`task_type` and :obj:`key`.
        """
        try:
            # Two-level lookup: task type first, then the prompt key.
            return self.task_prompt_template_dict[task_type][key]

        except KeyError as err:
            # Re-raise with a descriptive message; keep the original lookup
            # failure attached as the explicit cause.
            raise KeyError("Failed to get generate prompt template for "
                           f"task: {task_type.value} from key: {key}."
                           ) from err

    def get_system_prompt(
        self,
        task_type: TaskType,
        role_type: RoleType,
    ) -> TextPrompt:
        r"""Generates a text prompt for the system role, using the specified
        :obj:`task_type` and :obj:`role_type`.

        If no template is registered for the given combination, a warning
        is issued and a generic fallback prompt is returned instead of
        raising.

        Args:
            task_type (TaskType): The type of task.
            role_type (RoleType): The type of role, either "USER" or
                "ASSISTANT".

        Returns:
            TextPrompt: The generated text prompt, or the fallback prompt
                ``"You are a helpful assistant."`` when the lookup fails.
        """
        try:
            return self.get_prompt_from_key(task_type, role_type)

        except KeyError:
            # Missing template is not fatal here: degrade to a generic
            # system prompt and tell the caller via a warning.
            prompt = "You are a helpful assistant."

            warnings.warn("Failed to get system prompt template for "
                          f"task: {task_type.value}, role: {role_type.value}. "
                          f"Set template to: {prompt}")

        return TextPrompt(prompt)

    def get_generate_tasks_prompt(
        self,
        task_type: TaskType,
    ) -> TextPrompt:
        r"""Gets the prompt for generating tasks for a given task type.

        Args:
            task_type (TaskType): The type of the task.

        Returns:
            TextPrompt: The generated prompt for generating tasks.

        Raises:
            KeyError: If no ``"generate_tasks"`` template exists for
                :obj:`task_type`.
        """
        return self.get_prompt_from_key(task_type, "generate_tasks")

    def get_task_specify_prompt(
        self,
        task_type: TaskType,
    ) -> TextPrompt:
        r"""Gets the prompt for specifying a task for a given task type.

        Args:
            task_type (TaskType): The type of the task.

        Returns:
            TextPrompt: The generated prompt for specifying a task.

        Raises:
            KeyError: If no ``"task_specify_prompt"`` template exists for
                :obj:`task_type`.
        """
        return self.get_prompt_from_key(task_type, "task_specify_prompt")