marcenacp's picture
Initial commit
cb5b71d
raw
history blame
4.52 kB
import dataclasses
import hashlib
import io
import tempfile
from etils import epath
import pandas as pd
import requests
from .names import find_unique_name
from .state import FileObject
from .state import FileSet
# Labels identifying the two kinds of resources handled by this module.
FILE_OBJECT = "File object"
FILE_SET = "File set"

# Every resource kind label, listed as: file object first, then file set.
RESOURCE_TYPES = [
    FILE_OBJECT,
    FILE_SET,
]
@dataclasses.dataclass
class FileType:
    """Describes a file format by its label, MIME type and file extensions."""

    # Label of the format; used as the key of the `FILE_TYPES` mapping.
    name: str
    # MIME type string stored as the encoding format of created resources.
    encoding_format: str
    # File extensions associated with the format, written without the dot.
    extensions: list[str]
class FileTypes:
    """Namespace holding the predefined `FileType` constants."""

    # Comma-separated values.
    CSV = FileType(
        name="CSV",
        encoding_format="text/csv",
        extensions=["csv"],
    )
    # Excel workbooks (legacy, modern and macro-enabled extensions).
    EXCEL = FileType(
        name="Excel",
        encoding_format="application/vnd.ms-excel",
        extensions=["xls", "xlsx", "xlsm"],
    )
    # A single JSON document.
    JSON = FileType(
        name="JSON",
        encoding_format="application/json",
        extensions=["json"],
    )
    # JSON-Lines files.
    JSONL = FileType(
        name="JSON-Lines",
        encoding_format="application/jsonl+json",
        extensions=["jsonl"],
    )
    # Apache Parquet files.
    PARQUET = FileType(
        name="Parquet",
        encoding_format="application/vnd.apache.parquet",
        extensions=["parquet"],
    )
# Lookup table from a file type's `name` to the corresponding `FileType`.
FILE_TYPES: dict[str, FileType] = {
    FileTypes.CSV.name: FileTypes.CSV,
    FileTypes.EXCEL.name: FileTypes.EXCEL,
    FileTypes.JSON.name: FileTypes.JSON,
    FileTypes.JSONL.name: FileTypes.JSONL,
    FileTypes.PARQUET.name: FileTypes.PARQUET,
}
def _sha256(content: bytes):
    """Returns the hexadecimal SHA-256 digest of `content` as a string."""
    hasher = hashlib.sha256()
    hasher.update(content)
    return hasher.hexdigest()
def hash_file_path(url: str) -> epath.Path:
    """Maps `url` to a deterministic path inside the system temp directory.

    The file name embeds the sha256 digest of the encoded URL, so calling this
    function twice with the same URL yields the same path.
    """
    base_dir = epath.Path(tempfile.gettempdir())
    url_digest = _sha256(url.encode())
    file_name = f"croissant-editor-{url_digest}"
    return base_dir / file_name
def download_file(url: str, file_path: epath.Path, timeout: float = 60.0):
    """Downloads the file locally to `file_path`.

    The response body is streamed in chunks into a scratch file inside a
    temporary directory and only copied to `file_path` once the whole body has
    been written, so an error in the middle of the download does not leave a
    partial file at `file_path`.

    Args:
      url: Location of the file to download.
      file_path: Destination the downloaded file is copied to.
      timeout: Seconds to wait for the server (to connect, and between bytes
        received) before giving up. Without it, `requests` waits indefinitely
        on an unresponsive server. This is not a limit on the total download
        time.

    Raises:
      requests.HTTPError: If the server answers with an error status code.
      requests.Timeout: If the server does not respond within `timeout`.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with tempfile.TemporaryDirectory() as scratch_dir:
            scratch_file = epath.Path(scratch_dir) / "file"
            with scratch_file.open("wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
            # Copy while the temporary directory still exists, after the
            # scratch file has been closed (and therefore flushed).
            scratch_file.copy(file_path)
def get_dataframe(file_type: FileType, file: io.BytesIO | epath.Path) -> pd.DataFrame:
    """Loads `file` into a DataFrame with the pandas reader for `file_type`.

    Args:
      file_type: Format of the file; compared against the `FileTypes` constants.
      file: In-memory buffer or path handed to the pandas reader.

    Returns:
      The DataFrame produced by the matching pandas reader.

    Raises:
      NotImplementedError: If `file_type` equals none of the known file types.
    """
    # Ordered (file type, reader) pairs; the first equal file type wins.
    loaders = (
        (FileTypes.CSV, pd.read_csv),
        (FileTypes.EXCEL, pd.read_excel),
        (FileTypes.JSON, pd.read_json),
        (FileTypes.JSONL, lambda source: pd.read_json(source, lines=True)),
        (FileTypes.PARQUET, pd.read_parquet),
    )
    for known_type, load in loaders:
        if file_type == known_type:
            return load(file)
    raise NotImplementedError()
def file_from_url(file_type: FileType, url: str, names: set[str]) -> FileObject:
    """Builds a `FileObject` describing the file hosted at `url`.

    The file is downloaded to the path given by `hash_file_path(url)` unless
    that path already exists. Its bytes are then hashed and its content is
    parsed into a DataFrame.

    Args:
      file_type: Format used to parse the file and to set the encoding format.
      url: Location of the file; its last `/`-separated segment seeds the name.
      names: Names passed to `find_unique_name` together with the seed name.

    Returns:
      A `FileObject` with an empty description and `url` as its content URL.
    """
    local_path = hash_file_path(url)
    already_downloaded = local_path.exists()
    if not already_downloaded:
        download_file(url, local_path)
    with local_path.open("rb") as stream:
        raw_bytes = stream.read()
        digest = _sha256(raw_bytes)
    dataframe = get_dataframe(file_type, local_path).infer_objects()
    # Everything after the last "/" (the whole URL when there is no "/").
    base_name = url.rsplit("/", 1)[-1]
    return FileObject(
        name=find_unique_name(names, base_name),
        description="",
        content_url=url,
        encoding_format=file_type.encoding_format,
        sha256=digest,
        df=dataframe,
    )
def file_from_upload(
    file_type: FileType, file: io.BytesIO, names: set[str]
) -> FileObject:
    """Builds a `FileObject` from an uploaded in-memory file.

    Args:
      file_type: Format used to parse the file and to set the encoding format.
      file: Buffer holding the uploaded bytes. It must also expose a `name`
        attribute, which a plain `io.BytesIO` does not have. Presumably this
        is an upload widget's file object; confirm with the caller.
      names: Names passed to `find_unique_name` together with `file.name`.

    Returns:
      A `FileObject` with an empty description and a content URL of the form
      `data/<file.name>`.
    """
    raw_bytes = file.getvalue()
    digest = _sha256(raw_bytes)
    dataframe = get_dataframe(file_type, file).infer_objects()
    uploaded_name = file.name
    return FileObject(
        name=find_unique_name(names, uploaded_name),
        description="",
        content_url=f"data/{uploaded_name}",
        encoding_format=file_type.encoding_format,
        sha256=digest,
        df=dataframe,
    )
def file_from_form(
    file_type: FileType, type: str, name, description, sha256: str, names: set[str]
) -> FileObject | FileSet:
    """Builds a resource from fields that were filled in manually.

    Args:
      file_type: Format providing the encoding format of the resource.
      type: Kind of resource to create, either `FILE_OBJECT` or `FILE_SET`.
      name: Requested name; passed to `find_unique_name` together with `names`.
      description: Description stored on the resource.
      sha256: Digest stored on the resource; only used for `FILE_OBJECT`.
      names: Names passed to `find_unique_name`.

    Returns:
      A `FileObject` (with an empty content URL and no DataFrame) or a
      `FileSet`, depending on `type`.

    Raises:
      ValueError: If `type` is neither `FILE_OBJECT` nor `FILE_SET`.
    """
    if type == FILE_OBJECT:
        return FileObject(
            name=find_unique_name(names, name),
            description=description,
            content_url="",
            encoding_format=file_type.encoding_format,
            sha256=sha256,
            df=None,
        )
    if type == FILE_SET:
        return FileSet(
            name=find_unique_name(names, name),
            description=description,
            encoding_format=file_type.encoding_format,
        )
    raise ValueError("type has to be one of FILE_OBJECT, FILE_SET")