import json
from ...utils.data import get_dataset_content

from .values import (
    default_dataset_plain_text_input_variables_separator,
    default_dataset_plain_text_input_and_output_separator,
    default_dataset_plain_text_data_separator,
)


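def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
                        dataset_plain_text_input_variables_separator,
                        dataset_plain_text_input_and_output_separator,
                        dataset_plain_text_data_separator,
                        dataset_from_data_dir, prompter):
    """Load dataset items from pasted text or from the data directory.

    When `load_dataset_from` is "Text Input", `dataset_text` is parsed
    according to `dataset_text_format`: "JSON", "JSON Lines", or plain text
    for any other value. Otherwise the dataset named by
    `dataset_from_data_dir` is loaded with `get_dataset_content`.

    For plain text, an empty separator falls back to its default, and a
    literal backslash-n sequence in a separator is converted to a newline.
    """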
    if load_dataset_from == "Text Input":
        if dataset_text_format == "JSON":
            data = json.loads(dataset_text)

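        elif dataset_text_format == "JSON Lines":
            lines = dataset_text.split('\n')
            data = []
            for line_number, line in enumerate(lines, start=1):
                # Skip blank lines (such as the one left by a trailing
                # newline) instead of failing to parse them as JSON.
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    # Report the 1-based line number of the offending line.
                    raise ValueError(
                        f"Error parsing JSON on line {line_number}: {e}"
                    ) from e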

        else:  # Plain Text
            data = parse_plain_text_input(
                dataset_text,
                (
                    dataset_plain_text_input_variables_separator or
                    default_dataset_plain_text_input_variables_separator
                ).replace("\\n", "\n"),
                (
                    dataset_plain_text_input_and_output_separator or
                    default_dataset_plain_text_input_and_output_separator
                ).replace("\\n", "\n"),
                (
                    dataset_plain_text_data_separator or
                    default_dataset_plain_text_data_separator
                ).replace("\\n", "\n"),
                prompter.get_variable_names()
            )

    else:  # Load dataset from data directory
        data = get_dataset_content(dataset_from_data_dir)

    return data


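def parse_plain_text_input(
    value,
    variables_separator, input_output_separator, data_separator,
    variable_names
):
    """Parse plain-text dataset input into a list of items.

    `value` is split into items on `data_separator`. Each item is split on
    `input_output_separator` into an input part and an optional output part,
    and the input part is split on `variables_separator` into variable
    values, which are paired with `variable_names` in order.

    Returns a list of dicts of the form
    {'variables': {name: value, ...}, 'output': str}.
    """
    items = value.split(data_separator)
    result = []
    for item in items:
        parts = item.split(input_output_separator)
        variables = get_val_from_arr(parts, 0, "").split(variables_separator)
        variables = [it.strip() for it in variables]
        # zip() stops at the shorter sequence, so surplus values (or names
        # without a value) are dropped silently.
        variables_dict = dict(zip(variable_names, variables))
        # The output is optional and defaults to an empty string.
        output = get_val_from_arr(parts, 1, "").strip()
        result.append({'variables': variables_dict, 'output': output})
    return result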


def get_val_from_arr(arr, index, default=None):
    """Return arr[index], or `default` if the index is out of range."""
    return arr[index] if -len(arr) <= index < len(arr) else default