|
from langchain.prompts.chat import ChatMessagePromptTemplate |
|
|
|
|
|
class SpecialTokens:
    """Hold the special tokens used to frame prompt segments.

    Args:
        config: mapping-like object that must provide the keys
            ``"user_token"``, ``"assistant_token"``, ``"system_token"`` and
            ``"stop_token"``. Each value is stored under an attribute of the
            same name; a missing key raises ``KeyError``.
    """

    def __init__(self, config):
        # Keys are read in this fixed order, so the first missing key (in
        # this order) is the one reported by the KeyError.
        token_names = (
            "user_token",
            "assistant_token",
            "system_token",
            "stop_token",
        )
        for token_name in token_names:
            setattr(self, token_name, config[token_name])
|
|
|
|
|
def to_instruction(query, special_tokens):
    """Frame *query* as a user turn.

    Returns ``special_tokens.user_token + query + special_tokens.stop_token``.
    """
    opening = special_tokens.user_token
    closing = special_tokens.stop_token
    return opening + query + closing
|
|
|
|
|
def to_prompt(query, special_tokens):
    """Frame *query* as a user turn, then open the assistant's turn.

    Returns ``user_token + query + stop_token + assistant_token``, so the
    resulting text ends where the assistant's reply is expected to begin.
    """
    user_turn = special_tokens.user_token + query + special_tokens.stop_token
    return user_turn + special_tokens.assistant_token
|
|
|
|
|
def to_system(query, special_tokens):
    """Frame *query* as a system turn.

    Returns ``special_tokens.system_token + query + special_tokens.stop_token``.
    """
    opening = special_tokens.system_token
    closing = special_tokens.stop_token
    return opening + query + closing
|
|
|
|
|
def make_prompt(prompt, special_tokens):
    """Render one prompt config entry into a token-framed string.

    Args:
        prompt: mapping with a ``"type"`` key (``"system"``, ``"instruction"``
            or ``"prompt"``) and a ``"prompt"`` key holding either a list of
            lines (joined with ``"\\n"``) or a single string used as-is.
        special_tokens: object exposing ``user_token``, ``assistant_token``,
            ``system_token`` and ``stop_token`` attributes.

    Returns:
        The string produced by ``to_system``, ``to_instruction`` or
        ``to_prompt`` for the matching type.

    Raises:
        ValueError: if ``prompt["type"]`` is not a supported type.
    """
    prompt_type = prompt["type"]
    if prompt_type not in ("system", "instruction", "prompt"):
        # Fail loudly: returning an error message here would let it flow
        # into the model input as if it were a real prompt.
        raise ValueError(
            f"Invalid prompt type {prompt_type!r}, please check your config"
        )

    lines = prompt["prompt"]
    # "\n".join on a bare str would insert a newline between every character.
    query = lines if isinstance(lines, str) else "\n".join(lines)

    if prompt_type == "system":
        return to_system(query, special_tokens)
    if prompt_type == "instruction":
        return to_instruction(query, special_tokens)
    return to_prompt(query, special_tokens)
|
|
|
|
|
def to_chat_instruction(query, special_tokens):
    """Build a chat message template for a user instruction.

    The template text is *query* and the message role is
    ``special_tokens.user_token``. NOTE(review): ``from_template`` parses
    *query* as a template, so literal braces presumably need escaping —
    confirm.
    """
    user_role = special_tokens.user_token
    return ChatMessagePromptTemplate.from_template(query, role=user_role)
|
|
|
|
|
def to_chat_system(query, special_tokens):
    """Build a chat message template for a system message.

    The template text is *query* and the message role is
    ``special_tokens.system_token``.
    """
    system_role = special_tokens.system_token
    return ChatMessagePromptTemplate.from_template(query, role=system_role)
|
|
|
|
|
def to_chat_prompt(query, special_tokens):
    """Build a chat message template for a user prompt.

    The template text is *query* and the message role is
    ``special_tokens.user_token``.

    NOTE(review): this is identical to ``to_chat_instruction``; unlike the
    string-based ``to_prompt`` it does not open an assistant turn — confirm
    that is intended.
    """
    user_role = special_tokens.user_token
    return ChatMessagePromptTemplate.from_template(query, role=user_role)
|
|
|
|
|
def make_chat_prompt(prompt, special_tokens):
    """Build a chat message template from one prompt config entry.

    Args:
        prompt: mapping with a ``"type"`` key (``"system"``, ``"instruction"``
            or ``"prompt"``) and a ``"prompt"`` key holding either a list of
            lines (joined with ``"\\n"``) or a single string used as-is.
        special_tokens: object exposing ``user_token`` and ``system_token``
            attributes, used as message roles.

    Returns:
        The ``ChatMessagePromptTemplate`` produced by ``to_chat_system``,
        ``to_chat_instruction`` or ``to_chat_prompt`` for the matching type.

    Raises:
        ValueError: if ``prompt["type"]`` is not a supported type.
    """
    prompt_type = prompt["type"]
    if prompt_type not in ("system", "instruction", "prompt"):
        # Fail loudly: returning an error message here would hand callers a
        # plain string where a message template is expected.
        raise ValueError(
            f"Invalid prompt type {prompt_type!r}, please check your config"
        )

    lines = prompt["prompt"]
    # "\n".join on a bare str would insert a newline between every character.
    query = lines if isinstance(lines, str) else "\n".join(lines)

    if prompt_type == "system":
        return to_chat_system(query, special_tokens)
    if prompt_type == "instruction":
        return to_chat_instruction(query, special_tokens)
    return to_chat_prompt(query, special_tokens)
|
|