Skip to content

[PyTorch] 使用 random_split() 函式切割資料集

在使用 PyTorch 搭建深度學習環境的時候,若是我們有著切割資料集的需求(比方說將訓練資料切出驗證資料),在將資料封裝成 PyTorch 的 dataset 物件時,我們便可以透過 PyTorch 內建的切割函式 —— random_split() 來做到切割資料集。

以下就來介紹如何使用 random_split() 這個函式。


random_split() 函式

範例程式碼

使用方法非常簡單,以經典的 Mnist 手寫數字資料集作為範例。

# coding: utf-8
import torch.utils.data as data
from torchvision import datasets


def main():
    # Load data set
    train_set = datasets.MNIST(root='MNIST', download=True, train=True)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False)

    # Before
    print('Train data set:', len(train_set))
    print('Test data set:', len(test_set))

    # Random split
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size])
   
    # After
    print('='*30)
    print('Train data set:', len(train_set))
    print('Test data set:', len(test_set))
    print('Valid data set:', len(valid_set))


if __name__ == '__main__':
    main()



Output:

Train data set: 60000
Test data set: 10000
==============================
Train data set: 48000
Test data set: 10000
Valid data set: 12000

可以看到,random_split() 只需要輸入兩個參數: dataset 物件和切割資料的比例


固定亂數種子

random_split() 函式不像 scikit-learn 中的 train_test_split() 一樣可以直接設定亂數種子固定。如果要固定切割結果的話,需要在程式的開頭寫入:

import torch
torch.manual_seed(0)



References


Read More

2 thoughts on “[PyTorch] 使用 random_split() 函式切割資料集”

  1. 根據官方的說明書是可以設定random seed喔
    random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

Leave a Reply