Spaces:
Running
Running
import enum | |
from typing import Any | |
import streamlit as st | |
from core.state import Field | |
from core.state import RecordSet | |
from events.fields import ExtractType | |
from events.fields import FieldEvent | |
from events.fields import handle_field_change | |
from events.fields import TransformType | |
import mlcroissant as mlc | |
from utils import needed_field | |
class SourceType:
    """The type of the source (distribution or field)."""

    # String tags for the two kinds of source. The render functions in this
    # file compare a source's `node_type` against the literal "distribution",
    # which is the same value as DISTRIBUTION below.
    DISTRIBUTION = "distribution"
    FIELD = "field"
# Options offered by the "Extract" select boxes. `_get_extract_index` returns
# positions in this list, so the order here is the order shown to the user and
# must stay in sync with whatever `_get_extract` can return.
EXTRACT_TYPES = [
    ExtractType.COLUMN,
    ExtractType.JSON_PATH,
    ExtractType.FILE_CONTENT,
    ExtractType.FILE_NAME,
    ExtractType.FILE_PATH,
    ExtractType.FILE_FULLPATH,
    ExtractType.FILE_LINES,
    ExtractType.FILE_LINE_NUMBERS,
]
# Options offered by each "Transform" select box. `_get_transforms_indices`
# returns positions in this list, so the order here is the display order.
TRANSFORM_TYPES = [
    TransformType.FORMAT,
    TransformType.JSON_PATH,
    TransformType.REGEX,
    TransformType.REPLACE,
    TransformType.SEPARATOR,
]
def _get_extract(source: mlc.Source) -> str | None:
    """Translate the extract set on `source` into its `ExtractType` label.

    The checks run in priority order: column first, then file property, then
    JSON path. A file property that is set but not one of the handled
    `mlc.FileProperty` members yields None without falling through to the
    JSON path check.

    Returns:
        The matching `ExtractType` constant, or None when nothing matches.
    """
    extract = source.extract
    if extract.column:
        return ExtractType.COLUMN

    if extract.file_property:
        selected_property = extract.file_property
        # Handled file properties paired with the label shown for each.
        known_properties = (
            (mlc.FileProperty.content, ExtractType.FILE_CONTENT),
            (mlc.FileProperty.filename, ExtractType.FILE_NAME),
            (mlc.FileProperty.filepath, ExtractType.FILE_PATH),
            (mlc.FileProperty.fullpath, ExtractType.FILE_FULLPATH),
            (mlc.FileProperty.lines, ExtractType.FILE_LINES),
            (mlc.FileProperty.lineNumbers, ExtractType.FILE_LINE_NUMBERS),
        )
        for candidate, label in known_properties:
            if selected_property == candidate:
                return label
        return None

    if extract.json_path:
        return ExtractType.JSON_PATH
    return None
def _get_extract_index(source: mlc.Source) -> int | None:
    """Return the position of `source`'s extract type within `EXTRACT_TYPES`.

    Returns:
        The list index, or None when the source has no extract type that
        appears in `EXTRACT_TYPES` (including when `_get_extract` gives None).
    """
    label = _get_extract(source)
    try:
        return EXTRACT_TYPES.index(label)
    except ValueError:
        # `label` (possibly None) is not one of the offered options.
        return None
def _get_transforms(source: mlc.Source) -> list[str]:
    """Return the `TransformType` label of every transform on `source`.

    The result has one entry per item of `source.transforms`, in order. An
    entry is None when `_get_transform` does not recognise the transform.
    """
    labels = []
    for transform in source.transforms:
        labels.append(_get_transform(transform))
    return labels
def _get_transform(transform: mlc.Transform) -> str | None:
    """Return the `TransformType` label for the first truthy attribute.

    Attributes are probed in a fixed order (format, JSON path, regex, replace,
    separator); the first one holding a truthy value decides the label.

    Returns:
        The matching `TransformType` constant, or None when every probed
        attribute is falsy.
    """
    # Order matters: it mirrors the priority used when several are set.
    probes = (
        ("format", TransformType.FORMAT),
        ("json_path", TransformType.JSON_PATH),
        ("regex", TransformType.REGEX),
        ("replace", TransformType.REPLACE),
        ("separator", TransformType.SEPARATOR),
    )
    for attribute, label in probes:
        if getattr(transform, attribute):
            return label
    return None
def _get_transforms_indices(source: mlc.Source) -> list[int]:
    """Return, per transform on `source`, its position in `TRANSFORM_TYPES`.

    The result is aligned with `source.transforms`. An entry is None when the
    transform's label is not found in `TRANSFORM_TYPES`.
    """
    positions = []
    for label in _get_transforms(source):
        if label in TRANSFORM_TYPES:
            positions.append(TRANSFORM_TYPES.index(label))
        else:
            positions.append(None)
    return positions
def _handle_remove_reference(field):
    """Clear the reference on `field` by assigning a fresh `mlc.Source()`."""
    blank_source = mlc.Source()
    field.references = blank_source
def render_source(
    record_set_key: int,
    record_set: RecordSet,
    field: Field,
    field_key: int,
    possible_sources: list[str],
):
    """Renders the form for the source.

    Draws, in order: a "Source" select box; for distribution sources an
    "Extract" select box plus the matching column / JSON path input; one row
    per existing transform (type selector, value input, remove button); and a
    button that appends a new empty transform.

    Every widget reports changes through `handle_field_change`, which is given
    the widget's own key (presumably so it can read the new value from
    `st.session_state` -- confirm in `events.fields`).

    Args:
        record_set_key: Not used by this function.
        record_set: Record set owning `field`; its name is part of the widget
            keys and is used to filter out sources starting with that name.
        field: Field whose `source` is edited.
        field_key: Not used by this function.
        possible_sources: Candidate source identifiers for the select box.
    """
    source = field.source
    prefix = f"source-{record_set.name}-{field.name}"

    def _handle_remove_transform(field, number):
        del field.source.transforms[number]

    def _handle_add_transform(field):
        if not field.source:
            field.source = mlc.Source(transforms=[])
        field.source.transforms.append(mlc.Transform())

    col1, col2, col3 = st.columns([1, 1, 1])

    # The preselected index must refer to the list actually shown. Computing
    # it against the unfiltered `possible_sources` would select the wrong
    # entry (or an out-of-range one) as soon as anything is filtered out.
    options = [s for s in possible_sources if not s.startswith(record_set.name)]
    source_index = options.index(source.uid) if source.uid in options else None
    source_key = f"{prefix}-source"
    col1.selectbox(
        needed_field("Source"),
        index=source_index,
        options=options,
        key=source_key,
        on_change=handle_field_change,
        args=(FieldEvent.SOURCE, field, source_key),
    )

    if source.node_type == "distribution":
        # The callback receives the same key the widget is registered under;
        # passing the "Source" widget's key here would hand it the wrong value.
        extract_key = f"{prefix}-extract"
        extract = col2.selectbox(
            needed_field("Extract"),
            index=_get_extract_index(source),
            key=extract_key,
            options=EXTRACT_TYPES,
            on_change=handle_field_change,
            args=(FieldEvent.SOURCE_EXTRACT, field, extract_key),
        )
        if extract == ExtractType.COLUMN:
            column_key = f"{prefix}-columnname"
            col3.text_input(
                needed_field("Column name"),
                value=source.extract.column,
                key=column_key,
                on_change=handle_field_change,
                args=(FieldEvent.SOURCE_EXTRACT_COLUMN, field, column_key),
            )
        if extract == ExtractType.JSON_PATH:
            json_path_key = f"{prefix}-jsonpath"
            col3.text_input(
                needed_field("JSON path"),
                value=source.extract.json_path,
                key=json_path_key,
                on_change=handle_field_change,
                args=(FieldEvent.SOURCE_EXTRACT_JSON_PATH, field, json_path_key),
            )

    # Transforms
    # Per transform type: (widget key suffix, label, `mlc.Transform` attribute).
    transform_inputs = {
        TransformType.FORMAT: ("transform-format", "Format", "format"),
        TransformType.JSON_PATH: ("jsonpath", "JSON path", "json_path"),
        TransformType.REGEX: ("regex", "Regular expression", "regex"),
        TransformType.REPLACE: ("replace", "Replace pattern", "replace"),
        TransformType.SEPARATOR: ("separator", "Separator", "separator"),
    }
    indices = _get_transforms_indices(field.source)
    if source.transforms:
        rows = enumerate(zip(indices, source.transforms))
        for number, (transform_index, transform) in rows:
            _, col2, col3, col4 = st.columns([4.5, 4, 4, 1])
            transform_key = f"{prefix}-{number}-transform"
            selected = col2.selectbox(
                "Transform",
                index=transform_index,
                key=transform_key,
                options=TRANSFORM_TYPES,
                on_change=handle_field_change,
                args=(FieldEvent.TRANSFORM, field, transform_key),
                kwargs={"number": number},
            )
            spec = transform_inputs.get(selected)
            if spec is not None:
                suffix, label, attribute = spec
                value_key = f"{prefix}-{number}-{suffix}"
                col3.text_input(
                    needed_field(label),
                    value=getattr(transform, attribute),
                    key=value_key,
                    on_change=handle_field_change,
                    # The selected transform type itself is passed as the event.
                    args=(selected, field, value_key),
                    # NOTE(review): every input sends type="format", whatever
                    # the transform type -- confirm handle_field_change
                    # ignores it.
                    kwargs={"number": number, "type": "format"},
                )
            col4.button(
                "✖️",
                key=f"{prefix}-{number}-remove-transform",
                on_click=_handle_remove_transform,
                args=(field, number),
            )

    col1, _, _ = st.columns([1, 1, 1])
    col1.button(
        "Add transform on data",
        key=f"{prefix}-close-fields",
        on_click=_handle_add_transform,
        args=(field,),
    )
def render_references(
    record_set_key: int,
    record_set: RecordSet,
    field: Field,
    field_key: int,
    possible_sources: list[str],
):
    """Renders the form for references.

    When the field has a truthy `references` value, or the "add" button's
    session-state entry is truthy, draws a "Reference" select box, for
    distribution references an extract selector with its column / JSON path
    input, and a remove button. Otherwise draws only the "add" button.

    Args:
        record_set_key: Not used by this function.
        record_set: Record set owning `field`; its name is part of the widget
            keys and is used to filter out sources starting with that name.
        field: Field whose `references` is edited.
        field_key: Not used by this function.
        possible_sources: Candidate source identifiers for the select box.
    """
    key = f"references-{record_set.name}-{field.name}"
    button_key = f"{key}-add-reference"
    has_clicked_button = st.session_state.get(button_key)
    references = field.references
    if references or has_clicked_button:
        col1, col2, col3, col4 = st.columns([4.5, 4, 4, 1])
        # The preselected index must refer to the list actually shown.
        # Computing it against the unfiltered `possible_sources` would select
        # the wrong entry (or an out-of-range one) once anything is filtered.
        options = [
            s for s in possible_sources if not s.startswith(record_set.name)
        ]
        index = options.index(references.uid) if references.uid in options else None
        # `key` is extended step by step below, so each nested widget's key
        # embeds the key of the widget it depends on.
        key = f"{key}-reference"
        col1.selectbox(
            "Reference",
            index=index,
            options=options,
            key=key,
            on_change=handle_field_change,
            args=(FieldEvent.REFERENCE, field, key),
        )
        if references.node_type == "distribution":
            key = f"{key}-extract-references"
            extract = col2.selectbox(
                needed_field("Extract the reference"),
                index=_get_extract_index(references),
                key=key,
                options=EXTRACT_TYPES,
                on_change=handle_field_change,
                args=(FieldEvent.REFERENCE_EXTRACT, field, key),
            )
            if extract == ExtractType.COLUMN:
                key = f"{key}-columnname"
                col3.text_input(
                    needed_field("Column name"),
                    value=references.extract.column,
                    key=key,
                    on_change=handle_field_change,
                    args=(FieldEvent.REFERENCE_EXTRACT_COLUMN, field, key),
                )
            if extract == ExtractType.JSON_PATH:
                key = f"{key}-jsonpath"
                col3.text_input(
                    needed_field("JSON path"),
                    value=references.extract.json_path,
                    key=key,
                    on_change=handle_field_change,
                    args=(FieldEvent.REFERENCE_EXTRACT_JSON_PATH, field, key),
                )
        col4.button(
            "✖️",
            key=f"{key}-remove-reference",
            on_click=_handle_remove_reference,
            args=(field,),
        )
    else:
        # Reaching here means the button state was falsy, so offer the button.
        st.button(
            "Add a join with another column/field",
            key=button_key,
        )