Skip to content

[Python] 使用 ShuffleSplit() 進行 cross-validation

Last Updated on 2021-07-24 by Clay

Cross-validation (交叉驗證) 是機器學習中『切割資料』的一個重要的觀念。簡單來說,當我們訓練一個模型時,我們通常會將資料分成『訓練資料』(Training data) 和『測試資料』(Test data),然後我們使用訓練資料訓練模型、並使用模型從來沒見過的測試資料評估模型的好壞。

但是如果我們在嘗試不同模型架構、參數時,我們一直維持著同樣的訓練資料和測試資料,我們很有可能就僅僅只是依據這種特定的資料切割再進行優化。實際上,當我們使用不同的訓練資料和測試資料時,可能這個模型就沒有那麼好的表現了。

過去我曾經介紹過使用 Scikit-Learn 的 train_test_split() 來切割訓練資料與測試資料,今天則是紀錄如何使用同樣為 Scikit-Learn 的 ShuffleSplit() 來快速切割出不同的資料組合。


簡單的 train_test_split

train_test_split() 其實已經相當好用了,事實上我們改變亂數表,幾乎可以如同 ShuffleSplit() 一般地使用 —— 不過 ShuffleSplit() 只需要幾行便可產生大量不同組合的資料。

以下是個簡單的 train_test_split() 範例:

# coding: utf-8
from sklearn.model_selection import train_test_split


# train_test_split
elements = list(range(10))
train_data, test_data = train_test_split(elements, train_size=0.8)
print('Train: {} Test: {}'.format(train_data, test_data))



Output:

Train: [5, 1, 8, 7, 4, 0, 9, 3] Test: [2, 6]

ShuffleSplit

以下則是 ShuffleSplit() 的參數:

  • n_splits (int, default=10): 產生的隨機資料組合數量
  • test_size: 測試資料比例,數值應介於 0.0 - 1.0 之間
  • train_size: 訓練資料比例,數值應介於 0.0 - 1.0 之間
  • random_state: 亂數種子

和 train_test_split() 一樣,test_sizetrain_size 只要設定其中一個即可,另外一個參數會自動幫忙計算。

以下為一個簡單的範例:

# coding: utf-8
from sklearn.model_selection import ShuffleSplit


# ShuffleSplit
elements = list(range(10))
rs = ShuffleSplit(n_splits=5, train_size=0.8)
for train_data, test_data in rs.split(elements):
    print('Train: {} Test: {}'.format(train_data, test_data))



Output:

Train: [9 5 2 4 3 8 7 6] Test: [1 0]
Train: [0 7 9 1 2 5 3 6] Test: [8 4]
Train: [2 7 4 6 1 5 9 0] Test: [8 3]
Train: [4 9 8 7 0 1 5 6] Test: [2 3]
Train: [9 7 0 6 8 1 2 3] Test: [4 5]

References


Read More

Leave a Reply