
[PyTorch] Use "random_split()" Function To Split Data Set

Last Updated on 2021-08-25 by Clay

If we need to split a data set for deep learning, for example into training and validation sets, we can use PyTorch's built-in random_split() function.

Below I will introduce how to use the random_split() function.


random_split() Function

Sample Code

Its usage is very simple. Let's take the classic MNIST handwritten digit data set as an example.

# coding: utf-8
import torch.utils.data as data
from torchvision import datasets


def main():
    # Load data set
    train_set = datasets.MNIST(root='MNIST', download=True, train=True)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False)

    # Before
    print('Train data set:', len(train_set))
    print('Test data set:', len(test_set))

    # Random split
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size])
   
    # After
    print('='*30)
    print('Train data set:', len(train_set))
    print('Test data set:', len(test_set))
    print('Valid data set:', len(valid_set))


if __name__ == '__main__':
    main()



Output:

Train data set: 60000
Test data set: 10000
==============================
Train data set: 48000
Test data set: 10000
Valid data set: 12000

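As you can see, random_split() only needs two arguments: the dataset object to split and a list of lengths for the resulting subsets. The lengths must add up to the size of the dataset, which is why valid_set_size is computed as len(train_set) - train_set_size rather than as a second rounded product. Since PyTorch 1.13, the list may also contain fractions that sum to 1 (for example [0.8, 0.2]) instead of absolute lengths.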
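The snippet below is a minimal sketch of those argument rules. To avoid downloading MNIST, it assumes a small hypothetical TensorDataset of 100 samples. It shows that the two returned Subset objects do not overlap, that lengths which do not add up to the dataset size are rejected, and that fractions work in PyTorch 1.13 or later.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for MNIST: 100 samples with one feature each
features = torch.arange(100, dtype=torch.float32).unsqueeze(1)
labels = torch.zeros(100, dtype=torch.long)
dataset = TensorDataset(features, labels)

# Absolute lengths: they must add up to len(dataset)
train_len = int(len(dataset) * 0.8)
valid_len = len(dataset) - train_len
train_subset, valid_subset = random_split(dataset, [train_len, valid_len])
print(len(train_subset), len(valid_subset))  # 80 20

# Each result is a Subset whose .indices point into the original dataset
overlap = set(train_subset.indices) & set(valid_subset.indices)
print(len(overlap))  # 0, so no sample appears in both subsets

# Lengths that do not add up to len(dataset) raise a ValueError
try:
    random_split(dataset, [50, 40])
except ValueError as err:
    print("Rejected:", err)

# Fractions instead of lengths (PyTorch 1.13 or later)
frac_train, frac_valid = random_split(dataset, [0.8, 0.2])
print(len(frac_train), len(frac_valid))  # 80 20
```

Because the subsets only store indices, splitting is cheap: no samples are copied, and each Subset reads from the original dataset when indexed.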


Fixed Random Seed

If we want the split to be the same every time the program runs, we can fix the random seed by adding the following code at the beginning of the program:

import torch
torch.manual_seed(0)
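As a sketch of what the fixed seed buys us, the example below (again using a hypothetical 10-sample TensorDataset) repeats the same split twice and checks that the indices match. It also shows random_split()'s generator argument, which takes a dedicated torch.Generator so that only the split is seeded and the global random state used elsewhere (for example, weight initialization) is left alone.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset with 10 samples
dataset = TensorDataset(torch.arange(10))

# Option 1: seed the global RNG before each split
torch.manual_seed(0)
a = random_split(dataset, [8, 2])[0].indices
torch.manual_seed(0)
b = random_split(dataset, [8, 2])[0].indices
print(a == b)  # True: same global seed, same split

# Option 2: pass a dedicated, seeded generator to random_split()
def split_with_seed(seed):
    gen = torch.Generator().manual_seed(seed)
    train_subset, valid_subset = random_split(dataset, [8, 2], generator=gen)
    return train_subset.indices, valid_subset.indices

first = split_with_seed(42)
second = split_with_seed(42)
print(first == second)  # True: same generator seed, same split
```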


