# File-viewer header captured with this source: author marcenacp,
# commit "Initial commit" (cb5b71d), file size 5.82 kB.
import enum
from typing import Any
import streamlit as st
from core.state import Field
from core.state import Metadata
import mlcroissant as mlc
class ExtractType:
    """Kinds of extraction that can be configured on a source.

    The members are plain strings rather than enum members, so they compare
    equal to the raw widget values read from Streamlit's session state.
    """

    # Extraction from structured records.
    COLUMN: str = "Column"
    JSON_PATH: str = "JSON path"
    # Extraction of per-file properties (mapped to `mlc.FileProperty`).
    FILE_CONTENT: str = "File content"
    FILE_NAME: str = "File name"
    FILE_PATH: str = "File path"
    FILE_FULLPATH: str = "Full path"
    FILE_LINES: str = "Lines in file"
    FILE_LINE_NUMBERS: str = "Line numbers in file"
class TransformType:
    """Kinds of transformation that can be applied to a field's source.

    Plain string constants: `handle_field_change` compares its `change`
    argument directly against them.
    """

    FORMAT: str = "Apply format"  # -> mlc.Transform(format=...)
    JSON_PATH: str = "Apply JSON path"  # -> mlc.Transform(json_path=...)
    REGEX: str = "Apply regular expression"  # -> mlc.Transform(regex=...)
    REPLACE: str = "Replace"  # -> mlc.Transform(replace=...)
    SEPARATOR: str = "Separator"  # -> mlc.Transform(separator=...)
def _get_source(source: mlc.Source | None, value: Any) -> mlc.Source:
    """Resets the extraction of `source` to match the extract type in `value`.

    Args:
        source: The source to update in place. When falsy, a fresh source with
            an empty `mlc.Extract()` is created and used instead.
        value: One of the `ExtractType` strings. Column and JSON-path
            extractions start out with an empty string; file extractions select
            the matching `mlc.FileProperty`. Any other value leaves the current
            extraction untouched.

    Returns:
        The updated (or newly created) source.
    """
    if not source:
        source = mlc.Source(extract=mlc.Extract())

    # Scanned in order; the first matching type wins. The lambdas defer
    # building the extract (and reading `mlc.FileProperty`) until a match.
    extract_builders = (
        (ExtractType.COLUMN, lambda: mlc.Extract(column="")),
        (
            ExtractType.FILE_CONTENT,
            lambda: mlc.Extract(file_property=mlc.FileProperty.content),
        ),
        (
            ExtractType.FILE_NAME,
            lambda: mlc.Extract(file_property=mlc.FileProperty.filename),
        ),
        (
            ExtractType.FILE_PATH,
            lambda: mlc.Extract(file_property=mlc.FileProperty.filepath),
        ),
        (
            ExtractType.FILE_FULLPATH,
            lambda: mlc.Extract(file_property=mlc.FileProperty.fullpath),
        ),
        (
            ExtractType.FILE_LINES,
            lambda: mlc.Extract(file_property=mlc.FileProperty.lines),
        ),
        (
            ExtractType.FILE_LINE_NUMBERS,
            lambda: mlc.Extract(file_property=mlc.FileProperty.lineNumbers),
        ),
        (ExtractType.JSON_PATH, lambda: mlc.Extract(json_path="")),
    )
    for extract_type, build_extract in extract_builders:
        if value == extract_type:
            source.extract = build_extract()
            break
    return source
class FieldEvent(enum.Enum):
"""Event that triggers a field change."""
NAME = "NAME"
DESCRIPTION = "DESCRIPTION"
DATA_TYPE = "DATA_TYPE"
SOURCE = "SOURCE"
SOURCE_EXTRACT = "SOURCE_EXTRACT"
SOURCE_EXTRACT_COLUMN = "SOURCE_EXTRACT_COLUMN"
SOURCE_EXTRACT_JSON_PATH = "SOURCE_EXTRACT_JSON_PATH"
TRANSFORM = "TRANSFORM"
TRANSFORM_FORMAT = "TRANSFORM_FORMAT"
REFERENCE = "REFERENCE"
REFERENCE_EXTRACT = "REFERENCE_EXTRACT"
REFERENCE_EXTRACT_COLUMN = "REFERENCE_EXTRACT_COLUMN"
REFERENCE_EXTRACT_JSON_PATH = "REFERENCE_EXTRACT_JSON_PATH"
# Maps each transform type offered in the UI to the `mlc.Transform` keyword
# argument that receives the value entered by the user.
_TRANSFORM_ARGUMENTS: dict[str, str] = {
    TransformType.FORMAT: "format",
    TransformType.JSON_PATH: "json_path",
    TransformType.REGEX: "regex",
    TransformType.REPLACE: "replace",
    TransformType.SEPARATOR: "separator",
}


def _new_source(uid: str) -> mlc.Source:
    """Builds a source pointing at `uid`.

    A uid containing "/" is treated as a field; any other uid as a distribution.
    """
    node_type = "field" if "/" in uid else "distribution"
    return mlc.Source(uid=uid, node_type=node_type)


def _with_extract(source: mlc.Source | None, extract: mlc.Extract) -> mlc.Source:
    """Sets `extract` on `source`, creating an empty source first if unset."""
    if not source:
        source = mlc.Source(extract=mlc.Extract())
    source.extract = extract
    return source


def _replace_transform(
    field: Field, number: int | None, **transform_kwargs: Any
) -> None:
    """Replaces transform `number` of `field.source` with a new `mlc.Transform`.

    The new transform is built from `transform_kwargs`. This is a no-op when
    `number` is missing or out of range, or when the field has no source yet
    (so there is no transform to replace).
    """
    if number is None or not field.source:
        return
    transforms = field.source.transforms
    # Negative indices are rejected too: they would silently address transforms
    # counted from the end of the list.
    if 0 <= number < len(transforms):
        transforms[number] = mlc.Transform(**transform_kwargs)


def handle_field_change(
    change: FieldEvent | str,
    field: Field,
    key: str,
    **kwargs,
):
    """Streamlit callback applying the value of widget `key` to `field`.

    Args:
        change: What was edited. Either a `FieldEvent`, or one of the
            `TransformType` strings when the value of a transform was edited.
        field: The field to update in place.
        key: Session-state key of the widget; its current value is read from
            `st.session_state[key]`.
        **kwargs: For transform changes, `number` is the index of the edited
            transform in `field.source.transforms`.
    """
    value = st.session_state[key]
    if change == FieldEvent.NAME:
        old_name, new_name = field.name, value
        if old_name != new_name:
            # Renaming goes through the metadata so it can update whatever is
            # keyed by the old name (presumably references; confirm in
            # Metadata.rename_field).
            metadata: Metadata = st.session_state[Metadata]
            metadata.rename_field(old_name=old_name, new_name=new_name)
        field.name = value
    elif change == FieldEvent.DESCRIPTION:
        field.description = value
    elif change == FieldEvent.DATA_TYPE:
        field.data_types = [value]
    elif change == FieldEvent.SOURCE:
        field.source = _new_source(value)
    elif change == FieldEvent.SOURCE_EXTRACT:
        field.source = _get_source(field.source, value)
    elif change == FieldEvent.SOURCE_EXTRACT_COLUMN:
        field.source = _with_extract(field.source, mlc.Extract(column=value))
    elif change == FieldEvent.SOURCE_EXTRACT_JSON_PATH:
        field.source = _with_extract(field.source, mlc.Extract(json_path=value))
    elif change == FieldEvent.TRANSFORM:
        # The widget value is not needed: the transform is reset to an empty one.
        _replace_transform(field, kwargs.get("number"))
    elif change in _TRANSFORM_ARGUMENTS:
        # NOTE(review): FieldEvent.TRANSFORM_FORMAT has no branch of its own;
        # transform values are dispatched through the TransformType strings.
        argument = _TRANSFORM_ARGUMENTS[change]
        _replace_transform(field, kwargs.get("number"), **{argument: value})
    elif change == FieldEvent.REFERENCE:
        field.references = _new_source(value)
    elif change == FieldEvent.REFERENCE_EXTRACT:
        field.references = _get_source(field.references, value)
    elif change == FieldEvent.REFERENCE_EXTRACT_COLUMN:
        field.references = _with_extract(
            field.references, mlc.Extract(column=value)
        )
    elif change == FieldEvent.REFERENCE_EXTRACT_JSON_PATH:
        field.references = _with_extract(
            field.references, mlc.Extract(json_path=value)
        )