save_splitting ¶
assert_valid_splits ¶
assert_valid_splits(
    splits: list[list[list[int], list[int]]],
    y: ndarray,
    *,
    non_empty: bool = True,
    each_selected_class_in_each_split_subset: bool = True,
    same_length_training_splits: bool = True
)
Verify that the splits are valid.
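The three keyword flags suggest what "valid" covers. The sketch below is one possible reading of them, not the library's implementation: `check_splits` and its shortened flag names are hypothetical, and "selected classes" is assumed to mean every class that appears anywhere in the splits.

```python
import numpy as np


def _as_index(idx):
    # An explicit integer dtype keeps empty index lists usable for indexing.
    return np.asarray(idx, dtype=int)


def check_splits(splits, y, *, non_empty=True, each_class_everywhere=True,
                 same_length_training=True):
    """Hypothetical validity check; raises AssertionError on the first violation."""
    y = np.asarray(y)
    # Assumption: "selected" classes are those present anywhere in the splits.
    used = np.concatenate([_as_index(part) for split in splits for part in split])
    selected = set(y[used].tolist())
    train_lengths = set()
    for train_idx, test_idx in splits:
        for name, idx in (("train", train_idx), ("test", test_idx)):
            if non_empty:
                assert len(idx) > 0, f"empty {name} subset"
            if each_class_everywhere:
                present = set(y[_as_index(idx)].tolist())
                assert present == selected, f"class missing in {name} subset"
        train_lengths.add(len(train_idx))
    if same_length_training:
        assert len(train_lengths) <= 1, "training subsets differ in length"
```

A split whose training subset contains only one of two classes would fail the second check, while a split that places every class in both subsets passes.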
fix_split_by_dropping_classes ¶
fix_split_by_dropping_classes(
    x: ndarray,
    y: ndarray,
    n_splits: int,
    spliter_kwargs: dict,
) -> list[list[list[int], list[int]]]
Fixes stratified splits for an edge case.
For each class that has fewer instances than the number of splits, we oversample the class to n_splits instances before splitting, and then remove all oversampled and original samples of that class from the splits. This effectively removes the class from the data without touching the indices of the remaining samples.
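The oversample-then-drop idea can be sketched with NumPy alone. This is an illustration under stated assumptions, not the library's code: `round_robin_folds` is a hypothetical stand-in for whatever stratified splitter `spliter_kwargs` configures, and `split_dropping_rare_classes` is an invented name.

```python
import numpy as np


def round_robin_folds(y, n_splits, rng):
    """Stand-in stratified splitter: deal each class's indices round-robin into folds."""
    folds = [[] for _ in range(n_splits)]
    for cls in np.unique(y):
        for k, i in enumerate(rng.permutation(np.flatnonzero(y == cls))):
            folds[k % n_splits].append(int(i))
    return folds


def split_dropping_rare_classes(y, n_splits, seed=0):
    """Return [train, test] index lists that omit classes with < n_splits instances."""
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    pairs = list(zip(classes.tolist(), counts.tolist()))
    rare = {c for c, n in pairs if n < n_splits}
    # Oversample: append labels so every rare class reaches n_splits instances.
    extra = [c for c, n in pairs for _ in range(max(0, n_splits - n))]
    y_over = np.concatenate([y, np.array(extra, dtype=y.dtype)])
    folds = round_robin_folds(y_over, n_splits, rng)
    # Remove oversampled positions (index >= len(y)) and rare-class originals.
    # Surviving indices still refer to positions in the original y.
    kept = [[i for i in fold if i < len(y) and y[i].item() not in rare]
            for fold in folds]
    return [
        [sorted(i for j, fold in enumerate(kept) if j != k for i in fold),
         sorted(kept[k])]
        for k in range(n_splits)
    ]
```

Because oversampled entries are appended after the original labels, they can be recognized by index alone, which is why the original indices never need to be renumbered.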
get_cv_split_for_data ¶
get_cv_split_for_data(
    x: ndarray,
    y: ndarray,
    splits_seed: int,
    n_splits: int,
    *,
    stratified_split: bool,
    safety_shuffle: bool = True,
    auto_fix_stratified_splits: bool = False,
    force_same_length_training_splits: bool = False
) -> list[list[list[int], list[int]]] | str
Safety shuffle and generate (safe) splits.
If the first entry of the returned value is a str, no valid split could be generated, and the str states the reason. Due to the safety shuffle, x and y are also returned, and these returned arrays must be used together with the splits.
Note: the function does not support repeated splits at this point. Simply call this function multiple times with different seeds to get repeated splits.
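A minimal sketch of the repeated-splits pattern described in the note: call a split function once per repeat with a different seed. Both `repeated_splits` and `toy_kfold` are hypothetical helpers; `toy_kfold` only stands in for `get_cv_split_for_data` so that the example runs on its own.

```python
import numpy as np


def toy_kfold(n_samples, n_splits, splits_seed):
    """Stand-in splitter: shuffled k-fold returning [train, test] index lists."""
    rng = np.random.RandomState(splits_seed)
    folds = np.array_split(rng.permutation(n_samples), n_splits)
    return [
        [np.concatenate(folds[:k] + folds[k + 1:]).tolist(), folds[k].tolist()]
        for k in range(n_splits)
    ]


def repeated_splits(make_splits, n_repeats, base_seed):
    """Call make_splits(seed) once per repeat, each time with a distinct seed."""
    return [make_splits(base_seed + r) for r in range(n_repeats)]


# Four repeats of 3-fold splitting over 10 samples, seeds 42..45.
reps = repeated_splits(lambda seed: toy_kfold(10, 3, seed), n_repeats=4, base_seed=42)
```

With the real function, `make_splits` would presumably be a small wrapper that forwards the seed as `splits_seed` and keeps the other arguments fixed.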
Test with:
import numpy as np

if __name__ == "__main__":
    # Class 4 has a single instance, fewer than n_splits=3,
    # so this input exercises auto_fix_stratified_splits.
    print(
        get_cv_split_for_data(
            x=np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).T,
            y=np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4]),
            splits_seed=42,
            n_splits=3,
            stratified_split=True,
            auto_fix_stratified_splits=True,
        )
    )
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ndarray` | The data to split. | required |
| `y` | `ndarray` | The labels to split. | required |
| `splits_seed` | `int` | The seed to use for the splits, or a RandomState object. | required |
| `n_splits` | `int` | The number of splits to generate. | required |
| `stratified_split` | `bool` | Whether to use stratified splits. | required |
| `safety_shuffle` | `bool` | Whether to shuffle the data before splitting. | `True` |
| `auto_fix_stratified_splits` | `bool` | Whether to try to fix stratified splits automatically. The fix drops classes with fewer than `n_splits` samples. | `False` |
| `force_same_length_training_splits` | `bool` | Whether to force the training splits to have the same number of samples. This is done by duplicating random instances in the training subset of each split that is too small until all training splits have the same number of samples. | `False` |
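
The duplication strategy described for `force_same_length_training_splits` could look roughly like the sketch below. It is an assumption-laden illustration: `pad_training_splits` is a hypothetical name, and padding to the length of the longest training subset is a guess at the target length.

```python
import numpy as np


def pad_training_splits(splits, seed=0):
    """Duplicate random training indices until all training subsets are equally long."""
    rng = np.random.default_rng(seed)
    # Assumption: the longest training subset defines the target length.
    target = max(len(train) for train, _ in splits)
    padded = []
    for train, test in splits:
        train = list(train)
        n_missing = target - len(train)
        if n_missing > 0:
            # Draw duplicates only from this split's own training indices,
            # so no test index can leak into the training subset.
            train += rng.choice(train, size=n_missing, replace=True).tolist()
        padded.append([train, list(test)])
    return padded


splits = [
    [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]],
    [[0, 1, 2, 6, 7, 8, 9], [3, 4, 5]],
    [[3, 4, 5, 6, 7, 8, 9], [0, 1, 2]],
]
padded = pad_training_splits(splits, seed=42)
```

Only the first split is padded here; it gains one duplicated index from its own training subset, and the test subsets stay untouched.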