Last Updated on 2021-08-25 by Clay
在使用 PyTorch 搭建深度學習環境的時候,若是我們有著切割資料集的需求(比方說將訓練資料切出驗證資料),在將資料封裝成 PyTorch 的 dataset 物件時,我們便可以透過 PyTorch 內建的切割函式 —— random_split()
來做到切割資料集。
以下就來介紹如何使用 random_split()
這個函式。
random_split() 函式
範例程式碼
使用方法非常簡單,以經典的 Mnist 手寫數字資料集作為範例。
# coding: utf-8 import torch.utils.data as data from torchvision import datasets def main(): # Load data set train_set = datasets.MNIST(root='MNIST', download=True, train=True) test_set = datasets.MNIST(root='MNIST', download=True, train=False) # Before print('Train data set:', len(train_set)) print('Test data set:', len(test_set)) # Random split train_set_size = int(len(train_set) * 0.8) valid_set_size = len(train_set) - train_set_size train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size]) # After print('='*30) print('Train data set:', len(train_set)) print('Test data set:', len(test_set)) print('Valid data set:', len(valid_set)) if __name__ == '__main__': main()
Output:
Train data set: 60000
Test data set: 10000
==============================
Train data set: 48000
Test data set: 10000
Valid data set: 12000
可以看到,random_split()
只需要輸入兩個參數: dataset 物件和切割資料的比例。
固定亂數種子
random_split()
函式不像 scikit-learn 中的 train_test_split()
一樣可以直接設定亂數種子固定。如果要固定切割結果的話,需要在程式的開頭寫入:
import torch torch.manual_seed(0)
References
- https://stackoverflow.com/questions/55820303/fixing-the-seed-for-torch-random-split
- https://discuss.pytorch.org/t/how-to-split-dataset-into-test-and-validation-sets/33987/2
根據官方的說明書是可以設定random seed喔
random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
哈哈,慚愧,是我沒詳細閱讀官方說明。
大家請勿模仿。
感謝告知!