
save_splitting

assert_valid_splits

assert_valid_splits(
    splits: list[list[list[int], list[int]]],
    y: ndarray,
    *,
    non_empty: bool = True,
    each_selected_class_in_each_split_subset: bool = True,
    same_length_training_splits: bool = True
)

Verify that the splits are valid, applying each check whose keyword flag is enabled.
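The following is a hypothetical re-implementation of the three checks. It is a sketch only: `check_splits_sketch` is not part of this module. It assumes that failures raise `AssertionError`, as the function name suggests. It also assumes that the "selected" classes are the classes of all samples appearing in at least one split, so a class dropped beforehand is not required. The actual function may differ.

```python
import numpy as np


def check_splits_sketch(
    splits: list[list[list[int]]],
    y: np.ndarray,
    *,
    non_empty: bool = True,
    each_selected_class_in_each_split_subset: bool = True,
    same_length_training_splits: bool = True,
) -> None:
    """Hypothetical check mirroring the three flags; raises AssertionError on failure."""
    # Assumption: "selected" classes are those of all samples used in any split.
    used = np.concatenate(
        [np.asarray(idx, dtype=int) for split in splits for idx in split]
    )
    selected_classes = set(np.unique(y[used]).tolist())

    for train_idx, test_idx in splits:
        if non_empty:
            # Neither the training nor the test subset may be empty.
            assert len(train_idx) > 0 and len(test_idx) > 0, "empty split subset"
        if each_selected_class_in_each_split_subset:
            # Every selected class must occur in both subsets of every split.
            for subset in (train_idx, test_idx):
                present = set(y[np.asarray(subset, dtype=int)].tolist())
                missing = selected_classes - present
                assert not missing, f"classes {sorted(missing)} missing from a subset"

    if same_length_training_splits:
        # All training subsets must contain the same number of samples.
        train_lengths = {len(train_idx) for train_idx, _ in splits}
        assert len(train_lengths) == 1, f"training lengths differ: {sorted(train_lengths)}"


y = np.array([0, 0, 0, 1, 1, 1])
good_splits = [[[0, 1, 3, 4], [2, 5]], [[0, 2, 3, 5], [1, 4]], [[1, 2, 4, 5], [0, 3]]]
check_splits_sketch(good_splits, y)  # passes silently
```

A split such as `[[0, 1, 2], [3, 4, 5]]` would fail the class check, because class 1 never appears in its training subset.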

fix_split_by_dropping_classes

fix_split_by_dropping_classes(
    x: ndarray,
    y: ndarray,
    n_splits: int,
    spliter_kwargs: dict,
) -> list[list[list[int], list[int]]]

Fixes stratified splits for the edge case in which a class has fewer instances than n_splits.

For each class with fewer instances than the number of splits, we oversample the class to n_splits instances before splitting and afterwards remove all oversampled and original samples of that class from the splits. This effectively removes the class from the data without changing the indices of the remaining samples.
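A minimal sketch of the oversample-then-remove idea follows. Both helpers are hypothetical: `stratified_kfold_sketch` is a toy stratified splitter that deals each class's shuffled indices round-robin into folds. It stands in for whatever splitter `spliter_kwargs` configures. `drop_rare_classes_sketch` shows the fix itself: rare classes are padded to n_splits instances, the extended labels are split, and every rare-class index (original or duplicate) is then filtered out.

```python
import numpy as np


def stratified_kfold_sketch(y, n_splits, seed):
    """Hypothetical stratified k-fold: deal each class's shuffled indices round-robin."""
    rng = np.random.default_rng(seed)
    fold_of = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        fold_of[idx] = np.arange(len(idx)) % n_splits
    all_idx = np.arange(len(y))
    return [
        [all_idx[fold_of != k].tolist(), all_idx[fold_of == k].tolist()]
        for k in range(n_splits)
    ]


def drop_rare_classes_sketch(y, n_splits, seed):
    """Oversample rare classes to n_splits, split, then strip every sample of those classes."""
    classes, counts = np.unique(y, return_counts=True)
    # Duplicate existing samples of each rare class until it has n_splits instances.
    extra = [
        np.flatnonzero(y == cls)[np.arange(n_splits - cnt) % cnt]
        for cls, cnt in zip(classes, counts)
        if cnt < n_splits
    ]
    extended_y = np.concatenate([y] + [y[e] for e in extra])
    splits = stratified_kfold_sketch(extended_y, n_splits, seed)

    # Oversampled copies have index >= len(y); original rare samples are flagged here.
    is_rare = np.isin(y, classes[counts < n_splits])

    def keep(i):
        return i < len(y) and not is_rare[i]

    # Indices of all remaining samples still refer to the unmodified y.
    return [[[i for i in tr if keep(i)], [i for i in te if keep(i)]] for tr, te in splits]


y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4])
splits = drop_rare_classes_sketch(y, n_splits=3, seed=42)
```

Here class 4 has a single sample, so it is removed from every split, while indices 0 to 8 keep their original meaning.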

get_cv_split_for_data

get_cv_split_for_data(
    x: ndarray,
    y: ndarray,
    splits_seed: int,
    n_splits: int,
    *,
    stratified_split: bool,
    safety_shuffle: bool = True,
    auto_fix_stratified_splits: bool = False,
    force_same_length_training_splits: bool = False
) -> list[list[list[int], list[int]]] | str

Apply a safety shuffle and generate (safe) splits.

If the first entry of the result is a str, no valid split could be generated and the string states the reason. Because of the safety shuffle, the original x and y are also returned and must be used together with the splits.
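The text does not define the safety shuffle precisely. The sketch below assumes it is a seeded joint permutation of the rows of x and y, and uses that to show why split indices are only meaningful together with the arrays they were computed on. `safety_shuffle_sketch` is a hypothetical helper, not part of this module.

```python
import numpy as np


def safety_shuffle_sketch(x, y, seed):
    """Jointly permute the rows of x and y with a seeded RNG; also return the permutation."""
    perm = np.random.default_rng(seed).permutation(len(y))
    return x[perm], y[perm], perm


x = np.arange(20).reshape(10, 2)
y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4])
x_shuf, y_shuf, perm = safety_shuffle_sketch(x, y, seed=42)

# Split indices computed on the shuffled arrays must be applied to those arrays;
# perm[idx] translates them back to row numbers of the unshuffled input.
test_idx = [0, 1, 2]
original_rows = perm[test_idx]
```

Applying `test_idx` to the unshuffled x would silently select different rows. That is the mismatch the warning above guards against.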

Note: the function does not support repeated splits at this point. To get repeated splits, call it multiple times with different seeds.
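The "one call per seed" pattern can be sketched with a hypothetical seeded k-fold splitter, `kfold_sketch`, standing in for a single call of this function. Deriving the seeds as `base_seed + r` is an assumption made for illustration. Any set of distinct seeds would do.

```python
import numpy as np


def kfold_sketch(n_samples, n_splits, seed):
    """Hypothetical stand-in for one call of a seeded k-fold splitter."""
    perm = np.random.default_rng(seed).permutation(n_samples)
    folds = np.array_split(perm, n_splits)
    return [
        [np.concatenate(folds[:k] + folds[k + 1:]).tolist(), folds[k].tolist()]
        for k in range(n_splits)
    ]


# Repeated cross-validation: one call per repetition, each with its own seed.
n_repeats, base_seed = 3, 42
repeated_splits = [
    kfold_sketch(10, n_splits=5, seed=base_seed + r) for r in range(n_repeats)
]
```

With the real function, the same loop would pass `splits_seed=base_seed + r` and collect one list of splits per repetition.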

Test with:

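    import numpy as np

    # Placed at the bottom of the module, where get_cv_split_for_data is defined.
    if __name__ == "__main__":
        print(
            get_cv_split_for_data(
                x=np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).T,
                # Class 4 has a single sample (< n_splits), so the auto-fix drops it.
                y=np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4]),
                splits_seed=42,
                n_splits=3,
                stratified_split=True,
                auto_fix_stratified_splits=True,
            )
        )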

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | ndarray | The data to split. | required |
| y | ndarray | The labels to split. | required |
| splits_seed | int | The seed to use for the splits, or a RandomState object. | required |
| n_splits | int | The number of splits to generate. | required |
| stratified_split | bool | Whether to use stratified splits. | required |
| safety_shuffle | bool | Whether to shuffle the data before splitting. | True |
| auto_fix_stratified_splits | bool | Whether to try to fix stratified splits automatically, by dropping classes with fewer than n_splits samples. | False |
| force_same_length_training_splits | bool | Whether to force all training splits to have the same number of samples. This is done by duplicating random instances in the training subset of any split that is too small until all training splits have the same number of samples. | False |
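A minimal sketch of the padding behaviour described for `force_same_length_training_splits` follows. `pad_training_splits_sketch` is hypothetical. It assumes the duplicates are drawn with replacement from the short split's own training subset and that the target length is the longest training subset.

```python
import numpy as np


def pad_training_splits_sketch(splits, seed):
    """Duplicate random training indices until every training subset has equal length."""
    rng = np.random.default_rng(seed)
    target = max(len(train_idx) for train_idx, _ in splits)
    padded = []
    for train_idx, test_idx in splits:
        n_missing = target - len(train_idx)
        # Draw the duplicates from this split's own training subset (with replacement).
        extra = rng.choice(train_idx, size=n_missing, replace=True).tolist()
        padded.append([list(train_idx) + extra, list(test_idx)])
    return padded


# 10 samples in 3 folds give test sizes 3, 3, 4, so training sizes are 7, 7, 6.
splits = [
    [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]],
    [[0, 1, 2, 3, 7, 8, 9], [4, 5, 6]],
    [[4, 5, 6, 7, 8, 9], [0, 1, 2, 3]],
]
padded = pad_training_splits_sketch(splits, seed=0)
```

Only the third training subset grows (by one duplicated index), and the test subsets are left untouched.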