Spaces:
Running
Running
File size: 5,815 Bytes
cb5b71d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import enum
from typing import Any
import streamlit as st
from core.state import Field
from core.state import Metadata
import mlcroissant as mlc
class ExtractType:
    """Human-readable labels for the kinds of extraction a source can perform.

    Each label is a plain string. `_get_source` compares an incoming value
    against these labels to decide which `mlc.Extract` to build.
    """

    COLUMN: str = "Column"  # -> mlc.Extract(column=...)
    JSON_PATH: str = "JSON path"  # -> mlc.Extract(json_path=...)
    FILE_CONTENT: str = "File content"  # -> FileProperty.content
    FILE_NAME: str = "File name"  # -> FileProperty.filename
    FILE_PATH: str = "File path"  # -> FileProperty.filepath
    FILE_FULLPATH: str = "Full path"  # -> FileProperty.fullpath
    FILE_LINES: str = "Lines in file"  # -> FileProperty.lines
    FILE_LINE_NUMBERS: str = "Line numbers in file"  # -> FileProperty.lineNumbers
class TransformType:
    """Human-readable labels for the kinds of transformation a source can apply.

    `handle_field_change` receives these labels directly as its `change`
    argument when a single transform of a field is edited.
    """

    FORMAT: str = "Apply format"  # -> mlc.Transform(format=...)
    JSON_PATH: str = "Apply JSON path"  # -> mlc.Transform(json_path=...)
    REGEX: str = "Apply regular expression"  # -> mlc.Transform(regex=...)
    REPLACE: str = "Replace"  # -> mlc.Transform(replace=...)
    SEPARATOR: str = "Separator"  # -> mlc.Transform(separator=...)
def _get_source(source: mlc.Source | None, value: Any) -> mlc.Source:
    """Return `source` with its extract reset to match the extraction label `value`.

    Args:
        source: The source to update. When falsy, a new ``mlc.Source`` with an
            empty ``mlc.Extract`` is created instead. An existing source is
            mutated in place.
        value: One of the `ExtractType` labels. Any other value leaves the
            extract untouched.

    Returns:
        The (possibly newly created) source.
    """
    result = source if source else mlc.Source(extract=mlc.Extract())
    # (label, factory) pairs, checked in order. The factories are lambdas so
    # that only the Extract for the matching label is ever constructed.
    label_to_extract = (
        (ExtractType.COLUMN, lambda: mlc.Extract(column="")),
        (
            ExtractType.FILE_CONTENT,
            lambda: mlc.Extract(file_property=mlc.FileProperty.content),
        ),
        (
            ExtractType.FILE_NAME,
            lambda: mlc.Extract(file_property=mlc.FileProperty.filename),
        ),
        (
            ExtractType.FILE_PATH,
            lambda: mlc.Extract(file_property=mlc.FileProperty.filepath),
        ),
        (
            ExtractType.FILE_FULLPATH,
            lambda: mlc.Extract(file_property=mlc.FileProperty.fullpath),
        ),
        (
            ExtractType.FILE_LINES,
            lambda: mlc.Extract(file_property=mlc.FileProperty.lines),
        ),
        (
            ExtractType.FILE_LINE_NUMBERS,
            lambda: mlc.Extract(file_property=mlc.FileProperty.lineNumbers),
        ),
        (ExtractType.JSON_PATH, lambda: mlc.Extract(json_path="")),
    )
    for label, make_extract in label_to_extract:
        if value == label:
            result.extract = make_extract()
            break
    return result
class FieldEvent(enum.Enum):
"""Event that triggers a field change."""
NAME = "NAME"
DESCRIPTION = "DESCRIPTION"
DATA_TYPE = "DATA_TYPE"
SOURCE = "SOURCE"
SOURCE_EXTRACT = "SOURCE_EXTRACT"
SOURCE_EXTRACT_COLUMN = "SOURCE_EXTRACT_COLUMN"
SOURCE_EXTRACT_JSON_PATH = "SOURCE_EXTRACT_JSON_PATH"
TRANSFORM = "TRANSFORM"
TRANSFORM_FORMAT = "TRANSFORM_FORMAT"
REFERENCE = "REFERENCE"
REFERENCE_EXTRACT = "REFERENCE_EXTRACT"
REFERENCE_EXTRACT_COLUMN = "REFERENCE_EXTRACT_COLUMN"
REFERENCE_EXTRACT_JSON_PATH = "REFERENCE_EXTRACT_JSON_PATH"
def _source_from_uid(uid: str) -> mlc.Source:
    """Build a Source pointing at the node `uid`.

    A uid containing "/" is treated as a field; any other uid is treated as a
    distribution.
    """
    node_type = "field" if "/" in uid else "distribution"
    return mlc.Source(uid=uid, node_type=node_type)


def _with_extract(source: mlc.Source | None, extract: mlc.Extract) -> mlc.Source:
    """Return `source` with its extract replaced by `extract`.

    When `source` is falsy, a new ``mlc.Source`` is created first. An existing
    source is mutated in place and returned.
    """
    if not source:
        source = mlc.Source(extract=mlc.Extract())
    source.extract = extract
    return source


def _replace_transform(
    field: Field, number: int | None, **transform_kwargs: Any
) -> None:
    """Replace transform `number` of `field.source` with ``mlc.Transform(**transform_kwargs)``.

    This is a no-op in three cases:
    - `number` is None.
    - `number` is outside ``range(len(transforms))``. Negative indices are
      rejected rather than counted from the end.
    - The field has no source yet, so there is no transform to replace.
    """
    if number is None or not field.source:
        return
    transforms = field.source.transforms
    if 0 <= number < len(transforms):
        transforms[number] = mlc.Transform(**transform_kwargs)


# `change` values that edit one property of an existing transform, mapped to
# the mlc.Transform keyword argument that property is passed as.
_TRANSFORM_KEYWORDS = {
    TransformType.FORMAT: "format",
    TransformType.JSON_PATH: "json_path",
    TransformType.REGEX: "regex",
    TransformType.REPLACE: "replace",
    TransformType.SEPARATOR: "separator",
}


def handle_field_change(
    change: FieldEvent | str,
    field: Field,
    key: str,
    **kwargs: Any,
) -> None:
    """Apply the widget value stored under `key` to `field`.

    Args:
        change: Which property of `field` changed. This is either a FieldEvent
            member or, for edits to a single transform, one of the
            TransformType labels.
        field: The field being edited. It is mutated in place.
        key: Key in ``st.session_state`` holding the widget's new value.
        **kwargs: ``number`` is the index into ``field.source.transforms`` for
            transform changes. It is ignored for every other change.

    NOTE(review): FieldEvent.TRANSFORM_FORMAT matches no branch below, so it
    is a no-op. Format edits arrive as TransformType.FORMAT instead. Confirm
    whether that enum member is still used by callers.
    """
    value = st.session_state[key]
    if change == FieldEvent.NAME:
        old_name = field.name
        new_name = value
        if old_name != new_name:
            # Tell the metadata about the rename. Presumably this updates
            # anything that refers to the field by its old name; verify
            # against Metadata.rename_field.
            metadata: Metadata = st.session_state[Metadata]
            metadata.rename_field(old_name=old_name, new_name=new_name)
        field.name = value
    elif change == FieldEvent.DESCRIPTION:
        field.description = value
    elif change == FieldEvent.DATA_TYPE:
        field.data_types = [value]
    elif change == FieldEvent.SOURCE:
        field.source = _source_from_uid(value)
    elif change == FieldEvent.SOURCE_EXTRACT:
        field.source = _get_source(field.source, value)
    elif change == FieldEvent.SOURCE_EXTRACT_COLUMN:
        field.source = _with_extract(field.source, mlc.Extract(column=value))
    elif change == FieldEvent.SOURCE_EXTRACT_JSON_PATH:
        field.source = _with_extract(field.source, mlc.Extract(json_path=value))
    elif change == FieldEvent.TRANSFORM:
        # Reset the selected transform to an empty one.
        _replace_transform(field, kwargs.get("number"))
    elif change in _TRANSFORM_KEYWORDS:
        keyword = _TRANSFORM_KEYWORDS[change]
        _replace_transform(field, kwargs.get("number"), **{keyword: value})
    elif change == FieldEvent.REFERENCE:
        field.references = _source_from_uid(value)
    elif change == FieldEvent.REFERENCE_EXTRACT:
        field.references = _get_source(field.references, value)
    elif change == FieldEvent.REFERENCE_EXTRACT_COLUMN:
        field.references = _with_extract(
            field.references, mlc.Extract(column=value)
        )
    elif change == FieldEvent.REFERENCE_EXTRACT_JSON_PATH:
        field.references = _with_extract(
            field.references, mlc.Extract(json_path=value)
        )
|