import os from dotenv import load_dotenv import pandas as pd import io from langchain.output_parsers import PydanticOutputParser from langchain.prompts import ChatPromptTemplate from langchain.tools import PythonAstREPLTool from langchain.chat_models import ChatOpenAI from langchain.schema.output_parser import StrOutputParser from langchain.chat_models import ChatOpenAI from src.types import TableMapping from src.prompt import ( DATA_SCIENTIST_PROMPT_STR, SPEC_WRITER_PROMPT_STR, ENGINEER_PROMPT_STR, ) load_dotenv() if os.environ.get("DEBUG") == "true": os.environ["LANGCHAIN_WANDB_TRACING"] = "true" os.environ["WANDB_PROJECT"] = "llm-data-mapper" NUM_ROWS_TO_RETURN = 5 DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), "data") SYNTHETIC_DATA_DIR_PATH = os.path.join(DATA_DIR_PATH, "synthetic") # TODO: consider different models for different prompts, e.g. natural language prompt might be better with higher temperature BASE_MODEL = ChatOpenAI( model_name="gpt-4", temperature=0, ) def _get_data_str_from_df_for_prompt(df, num_rows_to_return=NUM_ROWS_TO_RETURN): return f"\n{df.head(num_rows_to_return).to_markdown()}\n" def get_table_mapping(source_df, template_df): """Use PydanticOutputParser to parse the output of the Data Scientist prompt into a TableMapping object.""" table_mapping_parser = PydanticOutputParser(pydantic_object=TableMapping) analyst_prompt = ChatPromptTemplate.from_template( template=DATA_SCIENTIST_PROMPT_STR, partial_variables={ "format_instructions": table_mapping_parser.get_format_instructions() }, ) mapping_chain = analyst_prompt | BASE_MODEL | table_mapping_parser table_mapping: TableMapping = mapping_chain.invoke( { "source_1_csv_str": _get_data_str_from_df_for_prompt(source_df), "target_csv_str": _get_data_str_from_df_for_prompt(template_df), } ) return pd.DataFrame(table_mapping.dict()["table_mappings"]) def _sanitize_python_output(text: str): """Remove markdown from python code, as prompt returns it.""" _, after = text.split("```python") return after.split("```")[0] def generate_mapping_code(table_mapping_df) -> str: """Chain two prompts together to generate python code from a table mapping: 1. technical spec writer, 2. python engineer""" writer_prompt = ChatPromptTemplate.from_template(SPEC_WRITER_PROMPT_STR) engineer_prompt = ChatPromptTemplate.from_template(ENGINEER_PROMPT_STR) writer_chain = writer_prompt | BASE_MODEL | StrOutputParser() engineer_chain = ( {"spec_str": writer_chain} | engineer_prompt | BASE_MODEL | StrOutputParser() | _sanitize_python_output ) return engineer_chain.invoke({"table_mapping": str(table_mapping_df.to_dict())}) def process_csv_text(value): """Process a CSV file into a dataframe, either from a string path or a file.""" if isinstance(value, str): df = pd.read_csv(value) else: df = pd.read_csv(value.name) return df def transform_source(source_df, code_text: str): """Use PythonAstREPLTool to transform a source dataframe using python code.""" return PythonAstREPLTool(locals={"source_df": source_df}).run(code_text)