PyTorch Datasets and Dataloaders


Datasets are the most important part of any deep learning algorithm, and most of the time in a model-building process is spent on data. Before feeding the collected data to the model, several operations, such as imputing missing values and encoding text data into numerical form, need to be performed so that our model can infer a meaningful conclusion.

Processing data may sometimes require a lot of code, so it is preferable to keep this code separate from our model for better readability. Fortunately, PyTorch has us covered. It provides two classes, Dataset and DataLoader, which help us use the available data efficiently. Dataset allows us to work with data, whether pre-loaded or custom-made, and DataLoader makes it convenient to access this data by wrapping an iterable around it.

2. Creating a custom Dataset in PyTorch

PyTorch’s Dataset class enables us to make our own dataset by inheriting its properties, which makes referring to individual samples easy. We can then use DataLoaders to iterate through these datasets and train our model.

a. Importing the required libraries:

import torch
from torch.utils.data import Dataset, DataLoader

b. Creating our own dataset class:

We will create a class to construct our dataset that inherits from PyTorch’s Dataset class, which lets us perform any operation on the custom dataset with ease.

class DataFlair_dataset(Dataset):

Firstly we will define a constructor with default values to build our dataset.

i. __init__

def __init__(self, length=100, transform=None):
    self.len = length
    self.x = 2 * torch.ones(length, 2)   # every sample's features are [2., 2.]
    self.y = torch.ones(length, 1)       # every sample's target is [1.]
    self.transform = transform

Now, we can define a getter method that retrieves the required data using indexing.

ii. __getitem__

def __getitem__(self, index):
    sample = self.x[index], self.y[index]
    if self.transform:
        sample = self.transform(sample)
    return sample

iii. __len__

# Get Length
def __len__(self):
    return self.len
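
Putting the three methods together, the complete class assembled from the pieces above looks like this:

class DataFlair_dataset(Dataset):
    def __init__(self, length=100, transform=None):
        self.len = length
        self.x = 2 * torch.ones(length, 2)
        self.y = torch.ones(length, 1)
        self.transform = transform

    def __getitem__(self, index):
        sample = self.x[index], self.y[index]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return self.len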


Finally, an instance of our DataFlair_dataset class can be created.

our_dataset = DataFlair_dataset()
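
As a quick sanity check, we can query the dataset’s length and index into it; the values in the comments follow directly from the constructor defaults. We can also exercise the optional transform hook by passing any callable, such as the hypothetical double_x below:

print(len(our_dataset))   # 100, served by __len__
print(our_dataset[0])     # (tensor([2., 2.]), tensor([1.])), served by __getitem__

# a hypothetical transform: double x and zero out y
def double_x(sample):
    x, y = sample
    return 2 * x, 0 * y

transformed = DataFlair_dataset(transform=double_x)
print(transformed[0])     # (tensor([4., 4.]), tensor([0.]))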

c. Printing our dataset:

To verify that our dataset has been constructed, we will print its first few samples.

j = 0
for i in our_dataset:
    print("x: ", i[0], "y: ", i[1])
    j += 1
    if j == 5:
        break

[Output screenshot: the first five (x, y) samples of the dataset]

d. Preprocessing the dataset using collate_fn:

collate_fn is a DataLoader parameter that takes a function which assembles, and can preprocess, each batch of samples as it is loaded. To demonstrate its functionality, we will build a function that divides the value of x by 10 and, for y, computes its modulus with 5.

 
def collate_fun(batch):
    # batch is a list of (x, y) samples; divide each x by 10, take each y modulo 5,
    # then stack the results into batch tensors
    xs = torch.stack([x / 10 for x, y in batch])
    ys = torch.stack([y % 5 for x, y in batch])
    return xs, ys

Once we have defined the required preprocessing, we can load the dataset using a DataLoader and set its collate_fn parameter to the function we have built.

DLoader = DataLoader(our_dataset, batch_size=2, collate_fn=collate_fun)

e. Printing the processed data using the dataloader:

i = 0
for data in DLoader:
    print(data)
    i += 1
    if i > 5:
        break

[Output screenshot: batches after applying collate_fun]

3. Using torchvision’s inbuilt Datasets and wrapping an iterable around them using DataLoaders

We will load the MNIST dataset and play around with it to see how loading a dataset works. This dataset contains images of handwritten digits, which we can use to train a deep learning model that identifies the digits in new images.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

a. Wrapping an iterable around our dataset:

train = datasets.MNIST("", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test = datasets.MNIST("", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))

In the above commands, we have loaded the MNIST dataset.

train=True/False – selects the training or the test split,

download=True – downloads the dataset if it is not already available on disk,

transform=transforms.Compose([transforms.ToTensor()]) – converts each image into a tensor so that it can be loaded onto a GPU if needed; further steps can be chained, as shown in the sketch below.
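
As a sketch, a normalisation step is often added after ToTensor; the mean and standard deviation below are the commonly quoted MNIST statistics, assumed here rather than computed in this tutorial:

transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # widely quoted MNIST mean/std (assumption)
])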

trainset = torch.utils.data.DataLoader(train, batch_size=20, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=20, shuffle=False)

We are specifying how we are going to iterate over the dataset. Here, the batch size is 20, which means we pass only 20 samples at a time. Training on small batches like this helps the model generalise.
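
To verify the batching, we can pull one batch and inspect its shape; MNIST images are 1×28×28, so with a batch size of 20 we expect the shapes shown in the comments:

images, labels = next(iter(trainset))
print(images.shape)   # torch.Size([20, 1, 28, 28])
print(labels.shape)   # torch.Size([20])

Iterating over the DataLoader then yields such batches one after another: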

for data in trainset:
    print(data)

[Output screenshot: a batch of image tensors and labels from the DataLoader]
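
Since matplotlib was imported above, we can also visualise a single sample from a batch. This is a minimal sketch, not part of the original pipeline:

images, labels = next(iter(trainset))
plt.imshow(images[0].squeeze(), cmap="gray")   # drop the channel dimension for plotting
plt.title("Label: " + str(labels[0].item()))
plt.show()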

There are many more datasets available in torchvision, like FashionMNIST, Caltech, Cityscapes, etc. We can perform the same operations we have done above on any of these datasets. Without Datasets and DataLoaders, it would have taken several more lines of code to load the data, set the batch size, convert the samples to tensors, etc., making our code complicated and difficult to read.
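
For instance, swapping in FashionMNIST is a one-line change; this is a sketch, and the same transform and DataLoader code apply unchanged:

fashion_train = datasets.FashionMNIST("", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))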

Summary

PyTorch does a great job of helping us build our own datasets and refer to them efficiently. With the added advantage of DataLoaders, a lot of our coding effort is saved, and our data pipeline becomes cleaner and more efficient.


