from abc import ABC, abstractmethod from typing import Dict, Any, List, Tuple, Optional, Union class AlpacaTemplate: DEFAULT_SYSTEM = ( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request." ) @classmethod def template( cls, user_content: str, system_content: Union[str, None] = None, response: Union[str, None] = None, ) -> str: template: str = "" if system_content: template += f"{system_content}\n\n" else: template += f"{cls.DEFAULT_SYSTEM}\n\n" template += f"### Instruction:\n{user_content}\n\n### Response:\n" if response: template += response return template class SftPrompt(ABC): @classmethod @abstractmethod def prompt_user_content(cls, data: Dict[str, Any]) -> str: pass @classmethod @abstractmethod def prompt_system_content(cls, data: Dict[str, Any]) -> Optional[str]: pass @classmethod @abstractmethod def prompt_target(cls, sql: str) -> str: pass @classmethod def prompt(cls, data: Dict[str, Any]) -> Tuple[str, Optional[str], str]: user_content = cls.prompt_user_content(data) system_content = cls.prompt_system_content(data) if "spark_sql" in data: target = cls.prompt_target(data["spark_sql"]) else: target = None return user_content, system_content, target class SQLGeneratePrompt(SftPrompt): @classmethod def split_res( cls, res: Union[str, List[str], None] ) -> Union[str, List[Union[str, None]], None]: if res == None: return res if isinstance(res, list): for i in range(len(res)): res[i] = cls.split_res(res[i]) return res else: res = res.strip() if not (res.startswith("```sql") and res.endswith("```")): return None return res[6:-3].strip() @classmethod def compose_extra_task_desc( cls, hint: Union[str, None], related_question_sqls: Union[List[Dict[str, str]], None], ) -> str: if not hint and not related_question_sqls: return "" elif hint and not related_question_sqls: return ( "I provide generation hint, you can refer to it to help you generate.\n" ) elif not hint and related_question_sqls: return "I provide other user queries related to the user query with the corresponding Spark SQL queries, you can refer to them to help you generate.\n" else: return "I provide generation hint and other user queries related to the user query with the corresponding Spark SQL queries, you can refer to them to help you generate.\n" @classmethod def compose_hint_content(cls, hint: Union[str, None]) -> str: if hint == None: return "" if len(hint) == 0: return "" hint = hint.strip() content = ( "[BEGIN OF GENERATION HINT]\n" f"{hint}\n" "[END OF GENERATION HINT]\n" "\n" ) return content @classmethod def compose_related_question_sqls_content( cls, related_question_sqls: Union[List[Dict[str, str]], None] ) -> str: if related_question_sqls == None: return "" if len(related_question_sqls) == 0: return "" content = "[BEGIN OF RELATED QUERIES]\n" for i, question_sql in enumerate(related_question_sqls): question, sql = question_sql["question"], question_sql["spark_sql"] question = question.strip() question = question.replace("\n", " ") sql = sql.strip() sub_content = ( f"# Related Query {i+1}\n" "## User Query\n" f"`{question}`\n" "## Spark SQL Query\n" "```sql\n" f"{sql}\n" "```\n" "\n" ) content = content + sub_content content = content.strip() + "\n" content += "[END OF RELATED QUERIES]\n\n" return content @classmethod def extract_table_schema(cls, user_content: str) -> str: start_idx = user_content.find("[BEGIN OF TABLE SCHEMAS]") + len( "[BEGIN OF TABLE SCHEMAS]" ) end_idx = user_content.find("[END OF TABLE SCHEMAS]") return user_content[start_idx:end_idx].strip() @classmethod def extract_user_query(cls, user_content: str) -> str: start_idx = user_content.find("[BEGIN OF QUERY]\nUser Query: ") + len( "[BEGIN OF QUERY]\nUser Query: " ) end_idx = user_content.find("[END OF QUERY]") return user_content[start_idx:end_idx].strip() @classmethod def extract_hint(cls, user_content: str) -> Union[str, None]: start_idx = user_content.find("[BEGIN OF GENERATION HINT]") + len( "[BEGIN OF GENERATION HINT]" ) end_idx = user_content.find("[END OF GENERATION HINT]") if end_idx == -1: return None return user_content[start_idx:end_idx].strip() @classmethod def extract_related_question_sqls( cls, user_content: str ) -> Union[List[Dict[str, str]], None]: start_idx = user_content.find("[BEGIN OF RELATED QUERIES]") + len( "[BEGIN OF RELATED QUERIES]" ) end_idx = user_content.find("[END OF RELATED QUERIES]") if end_idx == -1: return None related_question_sqls = user_content[start_idx:end_idx].strip().split("\n\n") res = [] for question_sql in related_question_sqls: question_start_idx = question_sql.find("`") + 1 question_end_idx = question_sql.find("`\n##") question = question_sql[question_start_idx:question_end_idx] sql_start_idx = question_sql.find("```sql\n") + 7 sql_end_idx = -3 sql = question_sql[sql_start_idx:sql_end_idx] res.append({"question": question, "spark_sql": sql}) return res @classmethod def prompt_user_content(cls, data: Dict[str, Any]) -> str: question, schema = data["question"].strip(), data["schema"].strip() hint, related_question_sqls = ( data["hint"], data["related_question_sqls"], ) extra_task_desc = cls.compose_extra_task_desc(hint, related_question_sqls) hint_content = cls.compose_hint_content(hint) related_question_sqls_content = cls.compose_related_question_sqls_content( related_question_sqls ) user_content = ( "[BEGIN OF TASK INSTRUCTION]\n" "You are an expert in composing Spark SQL queries. You are given a user query and a set of table schemas.\n" "Based on the user query, you need to generate one Spark SQL query to achieve the purpose.\n" f"{extra_task_desc}" "[END OF TASK INSTRUCTION]\n" "\n" "[BEGIN OF TABLE SCHEMAS]\n" f"{schema}\n" "[END OF TABLE SCHEMAS]\n" "\n" f"{hint_content}" f"{related_question_sqls_content}" "[BEGIN OF FORMAT INSTRUCTION]\n" "The output MUST strictly adhere to the following format, and NO other text MUST be included.\n" "```sql\n" "your output Spark SQL query\n" "```\n" "[END OF FORMAT INSTRUCTION]\n" "\n" "[BEGIN OF QUERY]\n" f"User Query: {question}\n" "[END OF QUERY]\n" ) user_content = AlpacaTemplate.template(user_content) return user_content @classmethod def prompt_system_content(cls, data: Dict[str, Any]) -> Union[str, None]: return None @classmethod def prompt_target(cls, sql: Union[str, List[str]]) -> Union[str, List[str]]: def map_func(tmp_sql: str) -> str: tmp_sql = tmp_sql.strip() tmp_sql = tmp_sql.strip(";") tmp_sql = tmp_sql.strip() tmp_sql += ";" return "```sql\n" f"{tmp_sql}\n" "```" if isinstance(sql, str): target = map_func(sql) else: target = [map_func(_sql) for _sql in sql] return target