File size: 2,618 Bytes
23add18
0b619bd
 
23add18
0b619bd
 
 
 
 
 
 
 
 
 
23add18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python
r"""Prompt template reader: load prompt queries from an XML template file.
-*- coding: utf-8 -*-

Module : prompt.reader

File Name : reader.py

Description : Read prompt template

Creation Date : 2024-07-16

Author : Frank Kang([email protected])
"""
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element


class Query(object):
    """One ``<query>`` element parsed from a prompt-template XML file.

    Attributes:
        rank (int): the node's ``rank`` attribute converted with ``int``.
        title (str): text of the node's ``<title>`` child
            (``None`` if that element has no text).
        text (str): text of the node's ``<text>`` child, i.e. the template
            body (``None`` if that element has no text).
    """

    def __init__(self, query_node: Element) -> None:
        super().__init__()
        rank_attr = query_node.get('rank')
        self.rank = int(rank_attr)
        self.title = Query.Get_Title(query_node)
        text_node = query_node.find('text')
        self.text = text_node.text

    @staticmethod
    def Get_Title(query_node: Element) -> str:
        """Return the text of the ``<title>`` child of *query_node*."""
        title_node = query_node.find('title')
        return title_node.text


class AssistantCreateQuery(Query):
    """Query for the ``System Message`` entry of a prompt template.

    Calling an instance formats the template text and returns the keyword
    arguments for ``client.beta.assistants.create``.
    """

    TITLE = 'System Message'

    def __init__(self, query_node: Element) -> None:
        super(AssistantCreateQuery, self).__init__(query_node)

    def __call__(self, name, *args, tools=None,
                 model="gpt-4-1106-preview", **kwds) -> dict:
        """Get parameters used for client.beta.assistants.create

        Args:
            name: value stored under the ``'name'`` key.
            *args: positional arguments forwarded to ``str.format`` on the
                template text.
            tools: tool specifications stored under ``'tools'``. ``None``
                (the default) produces a fresh
                ``[{"type": "code_interpreter"}]`` list on every call.
            model: value stored under the ``'model'`` key.
            **kwds: keyword arguments forwarded to ``str.format`` on the
                template text.

        Returns:
            dict: parameters used for client.beta.assistants.create, with
            keys ``'name'``, ``'instructions'``, ``'tools'`` and ``'model'``.
        """
        if tools is None:
            # Build a new list each call. A list literal used as the default
            # value would be created once and shared by every returned dict,
            # so a caller mutating one result would alter all later results.
            tools = [{"type": "code_interpreter"}]
        return {'name': name,
                'instructions': self.text.format(*args, **kwds),
                'tools': tools,
                'model': model}


class MessageQuery(Query):
    """Query for the ``User Message`` entry of a prompt template.

    Calling an instance fills the template text and wraps it in a
    user-role message dict.
    """

    TITLE = 'User Message'

    def __init__(self, query_node: Element) -> None:
        super().__init__(query_node)

    def __call__(self, *args, **kwds) -> dict:
        """Fill the template text like ``str.format`` and wrap it as a message.

        Args:
            *args: positional arguments forwarded to ``str.format``.
            **kwds: keyword arguments forwarded to ``str.format``.

        Returns:
            dict: ``{'role': 'user', 'content': <formatted template text>}``.
        """
        content = self.text.format(*args, **kwds)
        return dict(role='user', content=content)


class Prompt(object):
    def __init__(self, path) -> None:
        """Init Prompy by xml file

        Args:
            path (_type_): _description_
        """
        super(Prompt, self).__init__()
        self.path = path
        tree = ET.parse(path)
        body = tree.getroot()
        self.queries = {}
        for query in body.findall('query'):
            self.__read_query__(query)

    def __read_query__(self, query_node: Element):
        """Build the Query subclass selected by *query_node*'s title and store it.

        The text of the node's ``<title>`` child must equal
        ``AssistantCreateQuery.TITLE`` or ``MessageQuery.TITLE``. The
        resulting query is appended to ``self.queries[query.rank]``; the list
        is created on first use.

        Args:
            query_node: a ``<query>`` element.

        Raises:
            TypeError: if the title matches no supported query class.
        """
        title = Query.Get_Title(query_node)
        query_classes = {
            AssistantCreateQuery.TITLE: AssistantCreateQuery,
            MessageQuery.TITLE: MessageQuery,
        }
        query_cls = query_classes.get(title)
        if query_cls is None:
            # TypeError kept for backward compatibility with existing callers;
            # the message now names the offending title to aid debugging.
            raise TypeError(
                f'Title not supported! Got {title!r}, '
                f'expected one of {sorted(query_classes)}')
        query: Query = query_cls(query_node)
        self.queries.setdefault(query.rank, []).append(query)

    # def __call__(self, *args: ET.Any, **kwds: ET.Any) -> list: