File size: 2,618 Bytes
23add18 0b619bd 23add18 0b619bd 23add18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
#!/usr/bin/env python
r"""Read prompt templates (``<query>`` elements) from XML files.
-*- coding: utf-8 -*-
Module : prompt.reader
File Name : reader.py
Description : Read prompt template
Creation Date : 2024-07-16
Author : Frank Kang([email protected])
"""
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element
class Query(object):
    """Base class for a prompt template read from a ``<query>`` XML element.

    The element is expected to look like::

        <query rank="1">
            <title>...</title>
            <text>...</text>
        </query>

    Attributes:
        rank: Integer value of the element's ``rank`` attribute.
        title: Text of the ``<title>`` child.
        text: Text of the ``<text>`` child; ``''`` when that element is empty.

    Raises (from ``__init__``):
        TypeError: If the ``rank`` attribute is missing (``int(None)``).
        ValueError: If ``rank`` is not an integer literal.
        AttributeError: If the ``<title>`` or ``<text>`` child is missing.
    """

    def __init__(self, query_node: Element) -> None:
        super(Query, self).__init__()
        self.rank = int(query_node.get('rank'))
        self.title = Query.Get_Title(query_node)
        # An empty <text/> element has .text == None; normalise it to '' so
        # that subclasses can always call str.format on the template.
        text = query_node.find('text').text
        self.text = text if text is not None else ''

    @staticmethod
    def Get_Title(query_node: Element) -> str:
        """Return the text of ``query_node``'s ``<title>`` child.

        Returns None when the ``<title>`` element is present but empty.

        Raises:
            AttributeError: If the node has no ``<title>`` child.
        """
        return query_node.find('title').text
class AssistantCreateQuery(Query):
    """Query whose text is the instruction template for an assistant.

    Selected when the ``<title>`` text equals ``TITLE``. Calling an instance
    formats the template and returns keyword arguments for
    ``client.beta.assistants.create``.
    """

    TITLE = 'System Message'

    def __init__(self, query_node: Element) -> None:
        super(AssistantCreateQuery, self).__init__(query_node)

    def __call__(self, name, *args, tools=None, model="gpt-4-1106-preview", **kwds) -> dict:
        """Get parameters used for client.beta.assistants.create.

        Args:
            name: Value stored under the ``'name'`` key.
            *args: Positional arguments forwarded to ``str.format`` on the template.
            tools: Value stored under the ``'tools'`` key. When omitted (or None),
                a fresh ``[{"type": "code_interpreter"}]`` list is used.
            model: Value stored under the ``'model'`` key.
            **kwds: Keyword arguments forwarded to ``str.format`` on the template.

        Returns:
            dict: parameters used for client.beta.assistants.create, with keys
            ``'name'``, ``'instructions'``, ``'tools'`` and ``'model'``.
        """
        if tools is None:
            # Build the default per call. A list default in the signature would
            # be shared across calls, and since it is returned in the dict, any
            # caller mutation would leak into every later call.
            tools = [{"type": "code_interpreter"}]
        return {'name': name, 'instructions': self.text.format(*args, **kwds), 'tools': tools, 'model': model}
class MessageQuery(Query):
    """Query whose formatted text becomes the content of a user message.

    Selected when the ``<title>`` text equals ``TITLE``.
    """

    TITLE = 'User Message'

    def __init__(self, query_node: Element) -> None:
        super().__init__(query_node)

    def __call__(self, *args, **kwds) -> dict:
        """Fill the template like ``str.format`` and wrap it as a user message.

        Args:
            *args: Positional arguments forwarded to ``str.format``.
            **kwds: Keyword arguments forwarded to ``str.format``.

        Returns:
            dict: ``{'role': 'user', 'content': <formatted template>}``.
        """
        rendered = self.text.format(*args, **kwds)
        message = {'role': 'user'}
        message['content'] = rendered
        return message
class Prompt(object):
    """Collection of prompt queries parsed from an XML template file.

    Each direct ``<query>`` child of the root element is turned into an
    ``AssistantCreateQuery`` or ``MessageQuery`` (chosen by its ``<title>``
    text) and grouped by its integer ``rank``.

    Attributes:
        path: The path passed to the constructor.
        queries: Mapping of rank -> list of queries with that rank, in the
            order they appear in the document.
    """

    # Supported <title> texts and the query class that handles each one.
    _QUERY_TYPES = {
        AssistantCreateQuery.TITLE: AssistantCreateQuery,
        MessageQuery.TITLE: MessageQuery,
    }

    def __init__(self, path) -> None:
        """Init Prompt by xml file.

        Args:
            path: Anything accepted by ``xml.etree.ElementTree.parse``
                (a filename or a file object).

        Raises:
            TypeError: If a ``<query>`` has an unsupported ``<title>``.
        """
        super(Prompt, self).__init__()
        self.path = path
        body = ET.parse(path).getroot()
        self.queries = {}  # rank (int) -> list of Query objects
        for query in body.findall('query'):
            self.__read_query__(query)

    def __read_query__(self, query_node: Element) -> None:
        """Build the query object for ``query_node`` and file it under its rank.

        Raises:
            TypeError: If the ``<title>`` text is not a supported query type.
        """
        title = Query.Get_Title(query_node)
        query_cls = self._QUERY_TYPES.get(title)
        if query_cls is None:
            raise TypeError('Title not supported!')
        query: Query = query_cls(query_node)
        self.queries.setdefault(query.rank, []).append(query)
# def __call__(self, *args: ET.Any, **kwds: ET.Any) -> list:
|