Skip to content

[Scikit-Learn] 使用 train_test_split() 切割資料

Last Updated on 2021-09-06 by Clay

如果我們有『切資料』的需求 —— 比如說將資料切成 Training data (訓練資料) 以及 Test data (測試資料) ,我們便可以透過 Scikit-Learn 的 train_test_split() 這個函式來做到簡單的資料分割。

當然,你也可以使用 random 來自己完成這項工作,不過在 Python 中我們推崇的是『簡單、優雅』,而 train_test_split() 可以只使用『一行』來完成我們的需求。

當你有一個想要實驗的功能,赫然發現有個 Package 只需要呼叫一行便可以完成——再沒有什麼比這更棒的事了。

如果想要參閱 Scikit-Learn 關於 train_test_split() 函式的官方說明,可以參閱:

那麼,以下我就來簡單介紹怎麼使用這個方便的函式吧。


train_test_split() 的使用方法

如果是第一次使用 Scikit-Learn,需要用以下指令安裝:

pip3 install scikit-learn

安裝好後,我們就來試著 Demo 一下吧!首先,假設我們有以下這樣的資料:

data = [n for n in range(1, 11)]
print(data)


Output:

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

然後我們使用以下指令 Import train_test_split() 這個函式到我們的程式當中:

from sklearn.model_selection import train_test_split



train_test_split() 所接受的變數其實非常單純,基本上為 3 項:『原始的資料』、『Seed』、『比例』

  • 原始的資料:就如同上方的 data 一般,是我們打算切成 Training data 以及 Test data 的原始資料
  • Seed: 亂數種子,可以固定我們切割資料的結果
  • 比例:可以設定 train_size 或 test_size,只要設定一邊即可,範圍在 [0-1] 之間

以下我們來簡單測試一下效果:

train_data, test_data = train_test_split(data, random_state=777, train_size=0.8)
print(train_data)
print(test_data)



Output:

[9, 4, 5, 6, 2, 10, 7, 8]
[3, 1]

可以看到 training data 跟 test data 的比例的確便是 8:2 ,符合我們設定的比例。

以上就是簡單的 train_test_split() 紀錄,切資料相當方便好用。


References

Leave a Reply