zetavg
split finetune ui
a5e11b9 unverified
raw
history blame
2.76 kB
import json
from ...utils.data import get_dataset_content
from .values import (
default_dataset_plain_text_input_variables_separator,
default_dataset_plain_text_input_and_output_separator,
default_dataset_plain_text_data_separator,
)
def _resolve_plain_text_separator(separator, default):
    """Return the separator to use, falling back to `default` when blank.

    A literal backslash followed by ``n`` in the chosen value is converted
    into a real newline character.
    """
    return (separator or default).replace("\\n", "\n")


def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
                        dataset_plain_text_input_variables_separator,
                        dataset_plain_text_input_and_output_separator,
                        dataset_plain_text_data_separator,
                        dataset_from_data_dir, prompter):
    """Load dataset records from pasted text or from a stored dataset.

    Args:
        load_dataset_from: When equal to ``"Text Input"`` the data is parsed
            from `dataset_text`; any other value loads the data through
            ``get_dataset_content(dataset_from_data_dir)``.
        dataset_text: The raw text to parse (only used for "Text Input").
        dataset_text_format: ``"JSON"``, ``"JSON Lines"``, or anything else,
            which is treated as plain text.
        dataset_plain_text_input_variables_separator: Separator between the
            input variables of one plain-text item. Falsy values fall back to
            the module default.
        dataset_plain_text_input_and_output_separator: Separator between the
            input part and the output part of one plain-text item. Falsy
            values fall back to the module default.
        dataset_plain_text_data_separator: Separator between plain-text
            items. Falsy values fall back to the module default.
        dataset_from_data_dir: Passed as-is to ``get_dataset_content`` when
            not loading from text input.
        prompter: Object whose ``get_variable_names()`` result is used to
            name the plain-text input variables.

    Returns:
        For "JSON", whatever ``json.loads`` produces for the whole text.
        For "JSON Lines", a list with one parsed value per non-blank line.
        For plain text, the result of ``parse_plain_text_input``.
        Otherwise, the result of ``get_dataset_content``.

    Raises:
        ValueError: If a line of "JSON Lines" input cannot be parsed; the
            message carries the 1-based line number within `dataset_text`.
    """
    if load_dataset_from != "Text Input":
        # Load dataset from data directory
        return get_dataset_content(dataset_from_data_dir)

    if dataset_text_format == "JSON":
        return json.loads(dataset_text)

    if dataset_text_format == "JSON Lines":
        data = []
        for line_number, line in enumerate(dataset_text.split('\n'), start=1):
            # Blank lines (typically a trailing newline at the end of the
            # pasted text) carry no record; skip them instead of failing.
            # Line numbers still refer to positions in the original text.
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except Exception as e:
                raise ValueError(
                    f"Error parsing JSON on line {line_number}: {e}") from e
        return data

    # Plain Text
    return parse_plain_text_input(
        dataset_text,
        _resolve_plain_text_separator(
            dataset_plain_text_input_variables_separator,
            default_dataset_plain_text_input_variables_separator),
        _resolve_plain_text_separator(
            dataset_plain_text_input_and_output_separator,
            default_dataset_plain_text_input_and_output_separator),
        _resolve_plain_text_separator(
            dataset_plain_text_data_separator,
            default_dataset_plain_text_data_separator),
        prompter.get_variable_names()
    )
def parse_plain_text_input(
    value,
    variables_separator, input_output_separator, data_separator,
    variable_names
):
    """Parse separator-delimited plain text into a list of records.

    `value` is cut into items at `data_separator`. Each item is cut at
    `input_output_separator`: the first part holds the input variables
    (themselves separated by `variables_separator`) and the second part is
    the output; any further parts are ignored. Variable values are paired
    positionally with `variable_names`, stopping at the shorter of the two.
    All values and the output are stripped of surrounding whitespace, and a
    missing part is treated as an empty string.

    Returns:
        A list of ``{'variables': {name: value, ...}, 'output': str}`` dicts,
        one per item, in input order.
    """
    def build_record(chunk):
        segments = chunk.split(input_output_separator)
        raw_inputs = get_val_from_arr(segments, 0, "")
        cleaned_inputs = [
            piece.strip() for piece in raw_inputs.split(variables_separator)
        ]
        return {
            'variables': dict(zip(variable_names, cleaned_inputs)),
            'output': get_val_from_arr(segments, 1, "").strip(),
        }

    return [build_record(chunk) for chunk in value.split(data_separator)]
def get_val_from_arr(arr, index, default=None):
    """Return ``arr[index]``, or `default` when `index` is out of range.

    Negative indices are accepted as long as they fall within
    ``-len(arr) <= index < len(arr)``.
    """
    size = len(arr)
    if -size <= index < size:
        return arr[index]
    return default