Last Updated on 2021-08-25 by Clay
When we need to split a data set for deep learning, we can use PyTorch's built-in random_split() function.
In this post, I will show how to use random_split().
random_split() Function
Sample Code
The function is easy to use. Let's take the classic MNIST handwritten digit data set as an example.
# coding: utf-8
import torch.utils.data as data
from torchvision import datasets


def main():
    # Load data set
    train_set = datasets.MNIST(root='MNIST', download=True, train=True)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False)

    # Before
    print('Train data set:', len(train_set))
    print('Test data set:', len(test_set))

    # Random split
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size])

    # After
    print('='*30)
    print('Train data set:', len(train_set))
    print('Test data set:', len(test_set))
    print('Valid data set:', len(valid_set))


if __name__ == '__main__':
    main()
Output:
Train data set: 60000
Test data set: 10000
==============================
Train data set: 48000
Test data set: 10000
Valid data set: 12000
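As you can see, we only need to pass two arguments to random_split(): the dataset object and a list with the number of samples in each split. Here the sizes are computed from an 80/20 ratio, and they must add up to the length of the dataset.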
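To try the same size calculation without downloading MNIST, here is a minimal sketch on a toy TensorDataset. The helper name split_by_fraction and the toy data are my own assumptions for illustration, not part of PyTorch.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_by_fraction(dataset, train_fraction=0.8):
    """Hypothetical helper: turn a fraction into the two lengths random_split() expects."""
    train_size = int(len(dataset) * train_fraction)
    # Give the remainder to the validation split so the lengths always sum to len(dataset)
    valid_size = len(dataset) - train_size
    return random_split(dataset, [train_size, valid_size])


# Toy data set: 10 samples with one feature each, plus integer labels
features = torch.arange(10).float().unsqueeze(1)
labels = torch.arange(10)
toy_set = TensorDataset(features, labels)

train_subset, valid_subset = split_by_fraction(toy_set)
print(len(train_subset), len(valid_subset))  # 8 2
```

The two returned objects are Subset views of the original data set, so every sample ends up in exactly one of them.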
Fixed Random Seed
If we want the split to be the same on every run, we can set the random seed at the top of the program:
import torch
torch.manual_seed(0)
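If you would rather not change the global random state, another option is to pass a seeded generator to random_split(). This is a minimal sketch, assuming a PyTorch version whose random_split() accepts the generator argument. The helper seeded_split and the toy data set are hypothetical names used only for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy data set with 10 samples (illustration only)
toy_set = TensorDataset(torch.arange(10))


def seeded_split(dataset, lengths, seed=0):
    """Hypothetical helper: split with a local generator so the global seed is untouched."""
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, lengths, generator=generator)


first_train, first_valid = seeded_split(toy_set, [8, 2])
second_train, second_valid = seeded_split(toy_set, [8, 2])

# The same seed selects the same samples for each split
print(list(first_train.indices) == list(second_train.indices))  # True
```

Keeping the generator local means that other random operations in the program, such as weight initialization, do not affect which samples land in each split.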
References
- https://stackoverflow.com/questions/55820303/fixing-the-seed-for-torch-random-split
- https://discuss.pytorch.org/t/how-to-split-dataset-into-test-and-validation-sets/33987/2