Spaces:
Runtime error
Runtime error
import os | |
from dotenv import load_dotenv | |
import pandas as pd | |
import io | |
from langchain.output_parsers import PydanticOutputParser | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.tools import PythonAstREPLTool | |
from langchain.chat_models import ChatOpenAI | |
from langchain.schema.output_parser import StrOutputParser | |
from langchain.chat_models import ChatOpenAI | |
from src.types import TableMapping | |
from src.vars import NUM_ROWS_TO_RETURN | |
from src.prompt import DATA_SCIENTIST_PROMPT_STR, SPEC_WRITER_PROMPT_STR, ENGINEER_PROMPT_STR | |
# Load .env settings first, so they are in place before the chat model below is constructed.
load_dotenv()

# Data folders are resolved relative to this module's own location.
DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), 'data')
SYNTHETIC_DATA_DIR_PATH = os.path.join(DATA_DIR_PATH, 'synthetic')

# TODO: consider different models for different prompts, e.g. natural language prompt might be better with higher temperature
# One chat model instance shared by the chains built in this module.
BASE_MODEL = ChatOpenAI(model_name='gpt-4', temperature=0)
def _get_data_str_from_df_for_prompt(df, num_rows_to_return=NUM_ROWS_TO_RETURN):
    '''Render the first rows of a dataframe as a markdown table wrapped in <df> tags.

    Only ``df.head(num_rows_to_return)`` is rendered, so the resulting string
    holds at most that many rows of the frame.
    '''
    preview_markdown = df.head(num_rows_to_return).to_markdown()
    return f'<df>\n{preview_markdown}\n</df>'
def get_table_mapping(source_df, template_df):
    '''Run the Data Scientist prompt on previews of both frames and return the parsed mapping.

    The model reply is parsed by a PydanticOutputParser into a TableMapping;
    the entries under its 'table_mappings' key are returned as a DataFrame.
    '''
    parser = PydanticOutputParser(pydantic_object=TableMapping)
    format_instructions = parser.get_format_instructions()

    # The parser's format instructions are baked into the prompt as a partial variable.
    prompt = ChatPromptTemplate.from_template(
        template=DATA_SCIENTIST_PROMPT_STR,
        partial_variables={'format_instructions': format_instructions},
    )
    chain = prompt | BASE_MODEL | parser

    prompt_inputs = {
        "source_1_csv_str": _get_data_str_from_df_for_prompt(source_df),
        "target_csv_str": _get_data_str_from_df_for_prompt(template_df),
    }
    table_mapping: TableMapping = chain.invoke(prompt_inputs)

    mapping_rows = table_mapping.dict()['table_mappings']
    return pd.DataFrame(mapping_rows)
def _sanitize_python_output(text: str): | |
'''Remove markdown from python code, as prompt returns it.''' | |
_, after = text.split("```python") | |
return after.split("```")[0] | |
def generate_mapping_code(table_mapping_df) -> str:
    '''Turn a table mapping into python code via two chained prompts.

    Step 1 (technical spec writer) receives the mapping as a string; step 2
    (python engineer) receives the writer's output as ``spec_str``. The
    engineer's reply is passed through ``_sanitize_python_output`` before
    being returned.
    '''
    spec_chain = (
        ChatPromptTemplate.from_template(SPEC_WRITER_PROMPT_STR)
        | BASE_MODEL
        | StrOutputParser()
    )
    # The leading dict feeds the spec writer's output into the engineer prompt as "spec_str".
    code_chain = (
        {"spec_str": spec_chain}
        | ChatPromptTemplate.from_template(ENGINEER_PROMPT_STR)
        | BASE_MODEL
        | StrOutputParser()
        | _sanitize_python_output
    )
    chain_input = {"table_mapping": str(table_mapping_df.to_dict())}
    return code_chain.invoke(chain_input)
def process_csv_text(temp_file):
    '''Process CSV input into a dataframe, either from a string or a file.

    Accepted inputs:
    - ``str``: treated as the CSV text itself (not as a path).
    - an object whose ``.name`` is a filesystem path (e.g. a temp-file
      wrapper): the file at that path is read.
    - any other file-like object (e.g. ``io.StringIO``): handed directly to
      ``pd.read_csv``, instead of failing on the missing ``.name``.

    Returns the parsed ``pandas.DataFrame``; parsing errors from
    ``pd.read_csv`` propagate unchanged.
    '''
    if isinstance(temp_file, str):
        return pd.read_csv(io.StringIO(temp_file))

    file_path = getattr(temp_file, 'name', None)
    if isinstance(file_path, (str, os.PathLike)):
        return pd.read_csv(file_path)

    # In-memory buffers have no usable path; let pandas read the object itself.
    return pd.read_csv(temp_file)
def transform_source(source_df, code_text: str):
    '''Run python code against a source dataframe using PythonAstREPLTool.

    The tool is created with ``source_df`` as its only local variable, and
    whatever its ``run`` call returns for ``code_text`` is returned.

    NOTE(review): ``code_text`` is executed as Python; presumably it is the
    model-generated code from this module - confirm it is trusted before running.
    '''
    repl_locals = {'source_df': source_df}
    repl_tool = PythonAstREPLTool(locals=repl_locals)
    return repl_tool.run(code_text)