File size: 4,343 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from pandas.core.dtypes.missing import remove_na_arraylike
from pandas import (
MultiIndex,
concat,
)
from pandas.plotting._matplotlib.misc import unpack_single_str_list
if TYPE_CHECKING:
from collections.abc import Hashable
from pandas._typing import IndexLabel
from pandas import (
DataFrame,
Series,
)
def create_iter_data_given_by(
    data: DataFrame, kind: str = "hist"
) -> dict[Hashable, DataFrame | Series]:
    """
    Split grouped plot data into one sub-frame per label of a column level.

    This helper serves the `hist` and `box` plot kinds once `by` has been
    assigned. The columns of `data` must form a MultiIndex; the frame is
    partitioned along one of its levels and the pieces are returned in a
    dictionary keyed by the labels of that level.

    Parameters
    ----------
    data : reformatted grouped data from `_compute_plot_data` method.
    kind : str, plot kind. ``"hist"`` partitions on level 0 of the columns,
        any other value (i.e. ``"box"``) partitions on level 1.

    Returns
    -------
    iter_data : Dictionary of DataFrames
        One entry per label in the selected level; each value holds the
        columns of `data` whose entry at that level equals the key.

    Raises
    ------
    AssertionError
        If ``data.columns`` is not a MultiIndex.

    Examples
    --------
    If `by` is assigned:

    >>> import numpy as np
    >>> tuples = [('h1', 'a'), ('h1', 'b'), ('h2', 'a'), ('h2', 'b')]
    >>> mi = pd.MultiIndex.from_tuples(tuples)
    >>> value = [[1, 3, np.nan, np.nan],
    ...          [3, 4, np.nan, np.nan], [np.nan, np.nan, 5, 6]]
    >>> data = pd.DataFrame(value, columns=mi)
    >>> create_iter_data_given_by(data)
    {'h1':     h1
         a    b
    0  1.0  3.0
    1  3.0  4.0
    2  NaN  NaN, 'h2':     h2
         a    b
    0  NaN  NaN
    1  NaN  NaN
    2  5.0  6.0}
    """
    # `hist`: level 0 carries the group values, which double as subplot
    # titles and drive the column sub-selection and iteration.
    # `box`: level 1 carries the column names to show, which are iterated
    # over and used as subplot titles.
    level = 0 if kind == "hist" else 1

    # Once `by` is assigned the incoming frame must have MultiIndex columns.
    assert isinstance(data.columns, MultiIndex)

    # The per-column labels at the chosen level do not change between keys,
    # so look them up once and compare against each label in turn.
    labels_per_column = data.columns.get_level_values(level)
    return {
        label: data.loc[:, labels_per_column == label]
        for label in data.columns.levels[level]
    }
def reconstruct_data_with_by(
    data: DataFrame, by: IndexLabel, cols: IndexLabel
) -> DataFrame:
    """
    Regroup `data` by `by` and relabel the pieces with MultiIndex columns.

    Internal helper whose result is meant to be consumed by the
    `_compute_plot_data` method. The frame is grouped, the requested columns
    are taken from every group, each piece receives two-level column labels
    ``(group key, column)``, and the pieces are concatenated side by side.

    Parameters
    ----------
    data : Original DataFrame to plot
    by : grouped `by` parameter selected by users
    cols : columns of data set (excluding columns used in `by`)

    Returns
    -------
    Output is the reconstructed DataFrame with MultiIndex columns. The first level
    of MI is unique values of groups, and second level of MI is the columns
    selected by users.

    Examples
    --------
    >>> d = {'h': ['h1', 'h1', 'h2'], 'a': [1, 3, 5], 'b': [3, 4, 6]}
    >>> df = pd.DataFrame(d)
    >>> reconstruct_data_with_by(df, by='h', cols=['a', 'b'])
       h1      h2
       a     b     a     b
    0  1.0   3.0   NaN   NaN
    1  3.0   4.0   NaN   NaN
    2  NaN   NaN   5.0   6.0
    """
    # `by` is normalised through the shared helper before it reaches groupby.
    group_keys = unpack_single_str_list(by)

    pieces = []
    for group_key, group_frame in data.groupby(group_keys):
        # error: List item 1 has incompatible type "Union[Hashable,
        # Sequence[Hashable]]"; expected "Iterable[Hashable]"
        labelled_columns = MultiIndex.from_product(
            [[group_key], cols]  # type: ignore[list-item]
        )
        piece = group_frame[cols]
        # Relabel the selection so its columns read (group key, column).
        piece.columns = labelled_columns
        pieces.append(piece)

    # Rows keep their original index, so after the column-wise concat each
    # group's block is only populated on the rows that belonged to it.
    return concat(pieces, axis=1)
def reformat_hist_y_given_by(y: np.ndarray, by: IndexLabel | None) -> np.ndarray:
    """
    Drop missing values from hist input `y`, taking `by` into account.

    Without `by`, or when `y` has a single dimension, missing values are
    removed from `y` directly. With `by` set and a multi-dimensional `y`
    (the shape produced once groupby has taken place), missing values are
    removed column by column and the cleaned columns are stacked back into
    an array with the original column orientation.

    Parameters
    ----------
    y : array of values to plot
    by : grouped `by` parameter selected by users, or None

    Returns
    -------
    np.ndarray
        `y` with missing values removed.

    Notes
    -----
    The column-wise branch rebuilds the result with ``np.array``, so it
    presumably relies on every column keeping the same number of entries
    after NaN removal.
    """
    # Guard clause: nothing group-specific to do for ungrouped or 1-d input.
    if by is None or len(y.shape) <= 1:
        return remove_na_arraylike(y)

    # Iterate over columns (rows of the transpose), clean each one, then
    # transpose back so columns stay columns.
    cleaned_columns = [remove_na_arraylike(column) for column in y.T]
    return np.array(cleaned_columns).T
|