본문 바로가기
Computer Science/[21-22] ML & DL

[cv] 시계열 데이터 Cross Validation - GroupKFold

by gojw 2022. 4. 16.

GroupKFold is a variation of k-fold which ensures that the same group is not represented in both testing and training sets. For example if the data is obtained from different subjects with several samples per-subject and if the model is flexible enough to learn from highly person specific features it could fail to generalize to new subjects. GroupKFold makes it possible to detect this kind of overfitting situations.

https://scikit-learn.org/stable/modules/cross_validation.html#group-k-fold

 

GroupKFold은 교차검증을 할 때 같은 group이 train, validation set에 동시에 이용되지 않게 만든다.

kf.split(X, y, groups) 로 (n_sample,) 형태의 groups 파라미터를 추가해주면 된다. 

같은 group이 train, validation set에 나눠 들어가지 않아서 각각의 set의 길이가 다를 수 있다.

만약 모든 split의 train, validation set 길이를 같게 해주고싶으면 StratifiedGroupKFold를 이용하면 된다.

 

kf = GroupKFold(n_splits=5)

for fold_idx, (train_idx, valid_idx) in enumerate(kf.split(X_train, y_train, groups.unique())):
    X_train, X_valid = X_train[train_idx], X_train[valid_idx]
    y_train, y_valid = y_train[train_idx], y_train[valid_idx]

 

시계열 데이터를 다룰 때에는 scikit learn의 TimeSeriesSplit도 있지만, 

하나의 group당 n개의 타임스텝이 있는 형태의 데이터셋을 이용할 때 GroupKFold를 사용했다.

 

예를들어 3개의 group(10, 11, 12)이 있고 2개의 타임스텝(0, 1)이 있는 경우

GroupKFold를 쓰면 validation set에는 train set과 같은 group을 가진 데이터들이 존재하지 않는다.

 

 

댓글