File size: 3,490 Bytes
479f67b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
r"""_summary_
-*- coding: utf-8 -*-

Module : prompt.data

File Name : data.py

Description : Read prompt template

Creation Date : 2024-07-16

Author : Frank Kang([email protected])
"""
from typing_extensions import override
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element
from omegaconf import DictConfig
import os


class Trunk(DictConfig):
    def __init__(self, query_node: Element) -> None:
        super(Trunk, self).__init__(content={})
        for node in query_node:
            self[node.tag] = node.text


class Query():
    def __init__(self, query_node: Element) -> None:
        super(Query, self).__init__()
        self.rank = int(query_node.get('rank'))
        self.title = query_node.find('title').text
        self.text = query_node.find('text').text
        data = query_node.find('data')
        self.data = None

        if data is not None:
            self.data = [Trunk(trunk) for trunk in data.findall('trunk')]
            if len(self.data) == 0:
                self.data = None

    @staticmethod
    def Get_Title(query_node: Element) -> str:
        return query_node.find('title').text


class AssistantCreateQuery(Query):
    TITLE = 'System Message'

    def __init__(self, query_node: Element) -> None:
        super(AssistantCreateQuery, self).__init__(query_node)

    def __call__(self, *args,
                 name=None,
                 tools=[{"type": "code_interpreter"}],
                 model="gpt-4-1106-preview",
                 **kwds) -> dict:
        """Get parameters used for client.beta.assistants.create

        Returns:
            dict: parameters used for client.beta.assistants.create
        """
        return {'role': 'system', 'content': self.text.format(*args, **kwds)} if name is None else {'name': name, 'instructions': self.text.format(*args, **kwds), 'tools': tools, 'model': model}


class MessageQuery(Query):
    TITLE = 'User Message'

    def __init__(self, query_node: Element) -> None:
        super(MessageQuery, self).__init__(query_node)

    def __call__(self, *args, **kwds) -> dict:
        """Using like str.format

        Returns:
            dict: _description_
        """
        return {'role': 'user', 'content': self.text.format(*args, **kwds)}


class Prompt(object):
    def __init__(self, path) -> None:
        """Init Prompy by xml file

        Args:
            path (_type_): _description_
        """
        super(Prompt, self).__init__()
        self.path = path
        tree = ET.parse(path)
        body = tree.getroot()
        self.queries = {}
        self.name = '.'.join(os.path.basename(path).split('.')[:-1])
        for query in body.findall('query'):
            self.__read_query__(query)

    def __read_query__(self, query_node: Element):
        title = Query.Get_Title(query_node)
        query: Query
        if title == AssistantCreateQuery.TITLE:
            query = AssistantCreateQuery(query_node)
        elif title == MessageQuery.TITLE:
            query = MessageQuery(query_node)
        else:
            raise TypeError('Title not supported!')

        if query.rank not in self.queries:
            self.queries[query.rank] = [query]
        else:
            self.queries[query.rank].append(query)

    def __getitem__(self, rank):
        return self.queries[rank]

    @override
    def __repr__(self) -> str:
        return self.name
    # def __call__(self, *args: ET.Any, **kwds: ET.Any) -> list: