import enum
from typing import Any

import streamlit as st

from core.constants import RECORD_SETS
from core.query_params import set_tab
from core.state import Field
from core.state import Metadata
import mlcroissant as mlc


class ExtractType:
    """The type of extraction to perform.

    Plain string constants (not an ``enum.Enum``): `_get_source` compares them
    directly against the raw value read from ``st.session_state``.
    """

    COLUMN = "Column"
    JSON_PATH = "JSON path"
    FILE_CONTENT = "File content"
    FILE_NAME = "File name"
    FILE_PATH = "File path"
    FILE_FULLPATH = "Full path"
    FILE_LINES = "Lines in file"
    FILE_LINE_NUMBERS = "Line numbers in file"


class TransformType:
    """The type of transformation to perform.

    Plain string constants: `handle_field_change` matches them against its
    ``change`` argument to edit a single transform of ``field.source``.
    """

    FORMAT = "Apply format"
    JSON_PATH = "Apply JSON path"
    REGEX = "Apply regular expression"
    REPLACE = "Replace"
    SEPARATOR = "Separator"


# Maps each TransformType to the `mlc.Transform` keyword argument it sets.
_TRANSFORM_ARGUMENTS: dict[str, str] = {
    TransformType.FORMAT: "format",
    TransformType.JSON_PATH: "json_path",
    TransformType.REGEX: "regex",
    TransformType.REPLACE: "replace",
    TransformType.SEPARATOR: "separator",
}


def _ensure_source(source: mlc.Source | None) -> mlc.Source:
    """Return `source`, or a new source with an empty extract if it is falsy."""
    if not source:
        return mlc.Source(extract=mlc.Extract())
    return source


def _get_source(source: mlc.Source | None, value: Any) -> mlc.Source:
    """Set the extract of `source` according to the selected ExtractType.

    Args:
        source: The source to update in place; a new one is created if falsy.
        value: One of the ExtractType constants. Any other value leaves the
            extract untouched.

    Returns:
        The (possibly newly created) source.
    """
    source = _ensure_source(source)
    file_properties = {
        ExtractType.FILE_CONTENT: mlc.FileProperty.content,
        ExtractType.FILE_NAME: mlc.FileProperty.filename,
        ExtractType.FILE_PATH: mlc.FileProperty.filepath,
        ExtractType.FILE_FULLPATH: mlc.FileProperty.fullpath,
        ExtractType.FILE_LINES: mlc.FileProperty.lines,
        ExtractType.FILE_LINE_NUMBERS: mlc.FileProperty.lineNumbers,
    }
    if value == ExtractType.COLUMN:
        source.extract = mlc.Extract(column="")
    elif value == ExtractType.JSON_PATH:
        source.extract = mlc.Extract(json_path="")
    elif value in file_properties:
        source.extract = mlc.Extract(file_property=file_properties[value])
    return source


def _source_from_uid(uid: str) -> mlc.Source:
    """Build a source pointing at `uid`.

    UIDs containing "/" get node_type "field"; all others get "distribution".
    """
    node_type = "field" if "/" in uid else "distribution"
    return mlc.Source(uid=uid, node_type=node_type)


def _replace_transform(
    field: Field, number: int | None, **transform_kwargs: Any
) -> None:
    """Replace `field.source.transforms[number]` with a new `mlc.Transform`.

    The new transform is built from `transform_kwargs`. This is a no-op when
    the field has no source, `number` is None, or `number` is not a valid
    non-negative index into the existing transforms.
    """
    source = field.source
    if source is None or number is None:
        return
    # Reject negative indices: they would silently edit a transform counted
    # from the end of the list instead of the one the caller selected.
    if 0 <= number < len(source.transforms):
        source.transforms[number] = mlc.Transform(**transform_kwargs)


class FieldEvent(enum.Enum):
    """Event that triggers a field change."""

    NAME = "NAME"
    DESCRIPTION = "DESCRIPTION"
    DATA_TYPE = "DATA_TYPE"
    SOURCE = "SOURCE"
    SOURCE_EXTRACT = "SOURCE_EXTRACT"
    SOURCE_EXTRACT_COLUMN = "SOURCE_EXTRACT_COLUMN"
    SOURCE_EXTRACT_JSON_PATH = "SOURCE_EXTRACT_JSON_PATH"
    TRANSFORM = "TRANSFORM"
    # NOTE(review): handle_field_change has no branch for TRANSFORM_FORMAT;
    # transform edits are matched via TransformType strings instead. Confirm
    # whether this member is still needed.
    TRANSFORM_FORMAT = "TRANSFORM_FORMAT"
    REFERENCE = "REFERENCE"
    REFERENCE_EXTRACT = "REFERENCE_EXTRACT"
    REFERENCE_EXTRACT_COLUMN = "REFERENCE_EXTRACT_COLUMN"
    REFERENCE_EXTRACT_JSON_PATH = "REFERENCE_EXTRACT_JSON_PATH"


def handle_field_change(
    change: FieldEvent | str,
    field: Field,
    key: str,
    **kwargs,
) -> None:
    """Apply the value stored in ``st.session_state[key]`` to `field`.

    Args:
        change: The FieldEvent describing what changed, or one of the
            TransformType strings when a single transform is edited.
        field: The field to update in place.
        key: The ``st.session_state`` key holding the new value.
        **kwargs: ``number`` (int) selects the transform index for
            FieldEvent.TRANSFORM and TransformType changes; other keys are
            ignored.
    """
    set_tab(RECORD_SETS)
    value = st.session_state[key]
    if change == FieldEvent.NAME:
        old_name = field.name
        new_name = value
        if old_name != new_name:
            # Let the metadata react to the rename before the field itself
            # is updated.
            metadata: Metadata = st.session_state[Metadata]
            metadata.rename_field(old_name=old_name, new_name=new_name)
        field.name = value
    elif change == FieldEvent.DESCRIPTION:
        field.description = value
    elif change == FieldEvent.DATA_TYPE:
        field.data_types = [value]
    elif change == FieldEvent.SOURCE:
        field.source = _source_from_uid(value)
    elif change == FieldEvent.SOURCE_EXTRACT:
        field.source = _get_source(field.source, value)
    elif change == FieldEvent.SOURCE_EXTRACT_COLUMN:
        field.source = _ensure_source(field.source)
        field.source.extract = mlc.Extract(column=value)
    elif change == FieldEvent.SOURCE_EXTRACT_JSON_PATH:
        field.source = _ensure_source(field.source)
        field.source.extract = mlc.Extract(json_path=value)
    elif change == FieldEvent.TRANSFORM:
        # The widget value is ignored: the selected transform is reset to a
        # default-constructed mlc.Transform().
        _replace_transform(field, kwargs.get("number"))
    elif change in _TRANSFORM_ARGUMENTS:
        argument = _TRANSFORM_ARGUMENTS[change]
        _replace_transform(field, kwargs.get("number"), **{argument: value})
    elif change == FieldEvent.REFERENCE:
        field.references = _source_from_uid(value)
    elif change == FieldEvent.REFERENCE_EXTRACT:
        field.references = _get_source(field.references, value)
    elif change == FieldEvent.REFERENCE_EXTRACT_COLUMN:
        field.references = _ensure_source(field.references)
        field.references.extract = mlc.Extract(column=value)
    elif change == FieldEvent.REFERENCE_EXTRACT_JSON_PATH:
        field.references = _ensure_source(field.references)
        field.references.extract = mlc.Extract(json_path=value)