PyTorch RNN


We have all wondered how Google knows what you want to search for even before you finish typing, how big companies gauge people's opinions about something from just a few reviews and social media posts, or how machine translation works. A big part of the answer has been Recurrent Neural Networks. A Recurrent Neural Network, often abbreviated as RNN, is widely used for sentence completion, sentiment analysis and translation, and it has long been a go-to architecture for problems involving sequences. Let's learn more about RNNs in PyTorch.

What is a Recurrent Neural Network?

A recurrent neural network processes its input one step at a time and is trained to predict what comes next in the sequence. A typical neural network has three kinds of layers: an input layer, one or more hidden layers and an output layer. In an RNN, the hidden layer takes the previous hidden state along with the current input. So the present state depends on both the current input and the inputs that came before it.
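The hidden-state update described above can be written as h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh), which is the update PyTorch documents for nn.RNN with its default tanh nonlinearity. The sketch below, which assumes a toy batch with arbitrarily chosen sizes, unrolls that update by hand using the weights of an nn.RNN layer and checks that it matches the layer's own output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes chosen for illustration only.
seq_len, batch, input_size, hidden_size = 6, 3, 4, 8

rnn = nn.RNN(input_size, hidden_size)  # one layer, tanh, batch_first=False
x = torch.randn(seq_len, batch, input_size)

# Unroll h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh) by hand.
h = torch.zeros(batch, hidden_size)  # nn.RNN uses a zero h_0 when none is given
manual_outputs = []
with torch.no_grad():
    for t in range(seq_len):
        h = torch.tanh(
            x[t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
            + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0
        )
        manual_outputs.append(h)
    manual = torch.stack(manual_outputs)  # (seq_len, batch, hidden_size)
    output, h_n = rnn(x)

print(torch.allclose(manual, output, atol=1e-5))  # True
```

Each step reuses the same weights; only the hidden state h carries information from one step to the next.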


Why not a Feed Forward Neural Network?

A feed forward network processes each input independently, so it cannot relate one output to the ones before it. That is fine for a task such as classifying images of cats and dogs: the network only needs the information in the current image and has nothing to do with its previous predictions. A sentence, however, has a structure in which each word depends on the words written before it, and this is where a feed forward network falls short.
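To make the contrast concrete, here is a small sketch, assuming two toy "sentences" made of random vectors that stand in for word embeddings and differ only in their first word. A linear (feed forward) layer gives the same output for the last word in both cases, while an RNN's output at the last step changes because it carries the earlier word forward in its hidden state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two toy "sentences" of 3 word vectors each (random stand-ins for embeddings)
# that differ only in their first word.
seq_a = torch.randn(3, 4)
seq_b = seq_a.clone()
seq_b[0] = torch.randn(4)

feed_forward = nn.Linear(4, 8)
rnn = nn.RNN(4, 8)  # expects (seq_len, batch, features) by default

with torch.no_grad():
    # Feed forward: the output for the last word only sees the last word.
    ff_a = feed_forward(seq_a[-1])
    ff_b = feed_forward(seq_b[-1])
    # RNN: the output at the last step depends on the whole prefix.
    rnn_a, _ = rnn(seq_a.unsqueeze(1))  # (3, 1, 8)
    rnn_b, _ = rnn(seq_b.unsqueeze(1))

print(torch.equal(ff_a, ff_b))               # True
print(torch.allclose(rnn_a[-1], rnn_b[-1]))  # False
```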

Working of a Recurrent Neural Network:

The aim is to map a sequence to a vector, called an encoding, and then decode it. To understand how an RNN works, let's take an example. Suppose we have the following set of sentences:

<start> I have a blue car <end>

<start> We live in a blue house <end>

<start> I have a car battery <end>

To make things simple, let’s assume the current word depends only on the previous word and not on the words before that.

When we pass the above sentences as input, the model learns the probability of each word appearing given the previous word.

The vocabulary of the example, apart from the <start> and <end> markers, is {I, have, a, blue, car, we, live, in, house, battery}.

The probability of a word appearing when given the previous word is represented in the table below.

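Probability of each word (columns) appearing next when the last word (rows) is known.

| Last word | I | have | a | blue | car | we | live | in | house | battery | <end> |
|---|---|---|---|---|---|---|---|---|---|---|---|
| <start> | 2/3 | 0 | 0 | 0 | 0 | 1/3 | 0 | 0 | 0 | 0 | 0 |
| I | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| have | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| a | 0 | 0 | 0 | 2/3 | 1/3 | 0 | 0 | 0 | 0 | 0 | 0 |
| blue | 0 | 0 | 0 | 0 | 1/2 | 0 | 0 | 0 | 1/2 | 0 | 0 |
| car | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1/2 | 1/2 |
| we | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| live | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
| in | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| house | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| battery | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |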
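The table can be reproduced by simple counting. The sketch below is plain Python (no RNN yet): it counts how often each word follows each previous word in the three example sentences and normalizes every row, using exact fractions. "We" is lower-cased to match the table's vocabulary; that normalization is the only assumption.

```python
from collections import Counter, defaultdict
from fractions import Fraction

sentences = [
    "<start> I have a blue car <end>",
    "<start> we live in a blue house <end>",  # "We" lower-cased to match the table
    "<start> I have a car battery <end>",
]

# Count how often each word follows each previous word.
counts = defaultdict(Counter)
for sentence in sentences:
    words = sentence.split()
    for prev, nxt in zip(words, words[1:]):
        counts[prev][nxt] += 1

# Normalize each row so it sums to 1: P(next word | previous word).
probs = {
    prev: {nxt: Fraction(c, sum(row.values())) for nxt, c in row.items()}
    for prev, row in counts.items()
}

print(probs["<start>"])  # {'I': Fraction(2, 3), 'we': Fraction(1, 3)}
print(probs["a"])        # {'blue': Fraction(2, 3), 'car': Fraction(1, 3)}
```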

This was a simple example, so most cells in the table are 0. In practical applications, with large vocabularies and many sentences, far more cells hold a non-zero probability of a word appearing next. Also, in practice, these probabilities depend not only on the previous word but also on the words before it. Therefore, instead of the previous word, we talk about the state of the network, which summarizes everything seen so far.
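As a sketch of how a state replaces the single previous word, the hypothetical TinyRNNLanguageModel below (embedding, then nn.RNN, then a linear layer over the vocabulary) produces a next-word distribution at every position of a prefix. The model is untrained, so its probabilities are meaningless until it is fitted to data; the layer sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Vocabulary from the example above (the ordering is an arbitrary choice).
vocab = ["<start>", "I", "have", "a", "blue", "car", "we", "live", "in",
         "house", "battery", "<end>"]
word_to_id = {w: i for i, w in enumerate(vocab)}

class TinyRNNLanguageModel(nn.Module):
    """Hypothetical toy model: embedding -> RNN -> scores over the vocabulary."""

    def __init__(self, vocab_size, embed_dim=8, hidden_dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len); states: (batch, seq_len, hidden_dim)
        states, _ = self.rnn(self.embed(token_ids))
        # Each state summarizes all words so far, not just the last one.
        return self.head(states)  # logits: (batch, seq_len, vocab_size)

model = TinyRNNLanguageModel(len(vocab))
prefix = torch.tensor([[word_to_id[w] for w in ["<start>", "I", "have", "a"]]])
with torch.no_grad():
    next_word_probs = torch.softmax(model(prefix), dim=-1)

print(next_word_probs.shape)  # torch.Size([1, 4, 12])
```

Training such a model with a cross-entropy loss on the example sentences would push these distributions toward the counts in the table, while still letting the prediction depend on the whole prefix.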

Types of Recurrent Neural Networks (RNN):

a. One to One

An RNN with only one input and one output is called a One to One network. Since there is no sequence to recur over, it behaves like an ordinary feed forward network.


b. One to Many

Recurrent Neural Networks with one input and multiple outputs are called One to Many recurrent neural networks. They are used in applications that generate a sequence from a single input, such as image captioning, where one image goes in and a sequence of words comes out.
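A One to Many setup can be sketched as follows, assuming a hypothetical OneToManyRNN: a single feature vector (standing in for the output of an image encoder) initializes the hidden state, and an nn.RNNCell then generates one token per step, feeding each greedy prediction back in as the next input. All sizes and the start-token id are illustrative assumptions, and the model is untrained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class OneToManyRNN(nn.Module):
    """Hypothetical sketch: one feature vector in, a sequence of token ids out."""

    def __init__(self, feature_dim, vocab_size, hidden_dim=16):
        super().__init__()
        self.init_hidden = nn.Linear(feature_dim, hidden_dim)  # single input -> h_0
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.cell = nn.RNNCell(hidden_dim, hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, start_id, max_len):
        h = torch.tanh(self.init_hidden(features))  # (batch, hidden_dim)
        token = torch.full((features.size(0),), start_id, dtype=torch.long)
        generated = []
        for _ in range(max_len):
            h = self.cell(self.embed(token), h)      # one recurrent step
            token = self.head(h).argmax(dim=-1)      # greedy choice of next token
            generated.append(token)
        return torch.stack(generated, dim=1)         # (batch, max_len)

model = OneToManyRNN(feature_dim=32, vocab_size=12)
image_features = torch.randn(2, 32)  # stand-in for features from an image encoder
with torch.no_grad():
    caption_ids = model(image_features, start_id=0, max_len=5)
print(caption_ids.shape)  # torch.Size([2, 5])
```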


c. Many to One

When there are many inputs and a single output, the recurrent network is called Many to One. In sentiment analysis, for example, the input is a sequence of words and the desired output is a single label, such as positive or negative, that captures the characteristics of the whole input.
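A minimal Many to One sketch, assuming a hypothetical ManyToOneClassifier and random vectors in place of real word embeddings: the RNN reads the whole sequence, and only its final hidden state is passed to a linear layer that produces the class scores.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ManyToOneClassifier(nn.Module):
    """Hypothetical sentiment-style classifier: a sequence in, one label out."""

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        # h_n: (num_layers, batch, hidden_dim); batch_first does not apply to h_n
        _, h_n = self.rnn(x)
        # Use only the final hidden state, which has seen the whole sequence.
        return self.fc(h_n[-1])  # (batch, num_classes)

model = ManyToOneClassifier(input_dim=5, hidden_dim=16, num_classes=2)
reviews = torch.randn(3, 10, 5)  # 3 toy "reviews", 10 word vectors each
logits = model(reviews)
print(logits.shape)  # torch.Size([3, 2])
```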


d. Many to Many

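A Many to Many Recurrent Neural Network takes a sequence of inputs and produces a sequence of outputs. The inputs and outputs can be equal in number (Equal Unit Size), for example when tagging every word of a sentence, or unequal (Unequal Unit Size), for example when translating a sentence into another language.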
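For the Equal Unit Size case, a simple sketch is to apply a linear layer to the RNN output at every time step, giving one prediction per input. The sizes and the hypothetical 4-class "tagger" head below are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Equal Unit Size many-to-many: one output per input step (e.g. tagging each word).
rnn = nn.RNN(input_size=5, hidden_size=16, batch_first=True)
tagger = nn.Linear(16, 4)  # 4 hypothetical tag classes

words = torch.randn(2, 7, 5)  # 2 sentences, 7 word vectors each
outputs, _ = rnn(words)       # (2, 7, 16): one hidden state per word
tag_logits = tagger(outputs)  # nn.Linear acts on the last dim -> (2, 7, 4)
print(tag_logits.shape)       # torch.Size([2, 7, 4])
```

For Unequal Unit Size tasks such as translation, a common design is an encoder RNN whose final hidden state initializes a separate decoder RNN, which then generates an output sequence of its own length.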


Problems with RNNs and their solutions:

An RNN depends on its previous states, so adjusting to very new information can be difficult. In our example, the word ‘battery’ is always followed by the end of the sentence, i.e. the probability of <end> appearing after ‘battery’ is 1. If a new sentence has some other word after ‘battery’, it is hard for the model to account for it.

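Two problems can arise when such a model is trained by backpropagating the error through many time steps. The gradient of the error can either grow exponentially (exploding gradients) or shrink towards zero (vanishing gradients). In the second case, distant inputs and new sentences have very little effect on the model.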
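The vanishing case can be observed directly. In the sketch below, assuming a default-initialized nn.RNN, a 50-step random input and a stand-in "loss" that depends only on the final hidden state, the gradient reaching the first input is many orders of magnitude smaller than the gradient reaching the last one. The exact numbers depend on the seed and initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len = 50
rnn = nn.RNN(input_size=4, hidden_size=16)  # default tanh and initialization
x = torch.randn(seq_len, 1, 4, requires_grad=True)

_, h_n = rnn(x)
h_n.sum().backward()  # stand-in "loss" that depends only on the final state

grad_norms = x.grad.squeeze(1).norm(dim=1)  # one gradient norm per time step
print(f"first step: {grad_norms[0].item():.2e}, last step: {grad_norms[-1].item():.2e}")
```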

When gradients explode, we can clip them whenever their norm crosses a chosen threshold. For vanishing gradients, we can use a Long Short-Term Memory (LSTM) network.
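Gradient clipping fits into an ordinary training step between backward() and the optimizer update. The sketch below performs one step on toy data with a hypothetical regression head and an arbitrarily chosen threshold of 1.0; torch.nn.utils.clip_grad_norm_ rescales all gradients together when their combined L2 norm exceeds that threshold and returns the norm measured before clipping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.RNN(input_size=5, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)  # hypothetical regression head
params = list(model.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
loss_fn = nn.MSELoss()

x = torch.randn(8, 20, 5)  # toy batch: 8 sequences of 20 steps
y = torch.randn(8, 1)      # toy regression targets

max_norm = 1.0  # clipping threshold (a hyperparameter, chosen arbitrarily here)
optimizer.zero_grad()
_, h_n = model(x)
loss = loss_fn(head(h_n[-1]), y)
loss.backward()
# Rescale all gradients together if their combined L2 norm exceeds max_norm.
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
optimizer.step()
```

PyTorch also offers torch.nn.utils.clip_grad_value_, which clips each gradient element independently instead of rescaling by the overall norm.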

LSTMs are a special kind of recurrent neural network whose cells contain a forget gate, an input gate, an output gate and a memory cell. These gates learn which information from the previous states to keep, which to discard, and how to combine it with the new input.
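In PyTorch, nn.LSTM can replace nn.RNN with the same constructor arguments and input layout. The sketch below, assuming toy sizes and the default (seq_len, batch, input_size) layout, shows the two visible differences: the LSTM also returns its memory cell state c_n, and its input-to-hidden weight matrix stacks four blocks (input gate, forget gate, cell candidate and output gate), so it has four times as many rows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(5, 4, 5)  # (seq_len, batch, input_size)

rnn = nn.RNN(input_size=5, hidden_size=8)
lstm = nn.LSTM(input_size=5, hidden_size=8)

rnn_out, h_n = rnn(x)
lstm_out, (lstm_h_n, c_n) = lstm(x)  # the LSTM also returns its cell state c_n

print(rnn_out.shape, lstm_out.shape)  # torch.Size([5, 4, 8]) torch.Size([5, 4, 8])
print(lstm.weight_ih_l0.shape)        # torch.Size([32, 5]): 4 stacked blocks x 8 units
```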

Implementing RNN using PyTorch:

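import torch
import torch.nn as nn

class RRNet(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # A single-layer RNN (num_layers=1) with the default tanh nonlinearity
        self.rnn_model = nn.RNN(input_dim, hidden_dim, 1)

    def forward(self, data):
        # Returns (output at every time step, final hidden state)
        return self.rnn_model(data)


# We pass the feature size (5) and the dimension of the hidden layer (1), respectively
rrnet = RRNet(5, 1)

# For the sake of demonstration, we create a random tensor.
# With the default batch_first=False, its shape is read as
# (sequence length, batch size, feature size) = (5, 4, 5).
data = torch.rand(5, 4, 5)

print(data)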

Output: a tensor of shape (5, 4, 5) filled with random values in [0, 1); the exact numbers change on every run.

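result, h_s = rrnet(data)  # calling the module runs forward() (and any registered hooks)

print(result)  # output of the RNN at every time step

print(h_s)  # hidden state after the last time step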
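It helps to check the shapes of the two returned tensors. The self-contained sketch below rebuilds an nn.RNN with the same sizes as the example above (input size 5, hidden size 1, one layer) on a seeded random input. For a single-layer, unidirectional RNN, the last time step of result is the final hidden state h_s.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same sizes as the example above: input_size=5, hidden_size=1, one layer.
rnn = nn.RNN(5, 1, 1)
data = torch.rand(5, 4, 5)  # (seq_len=5, batch=4, input_size=5)

result, h_s = rnn(data)
print(result.shape)  # torch.Size([5, 4, 1]): hidden state at every time step
print(h_s.shape)     # torch.Size([1, 4, 1]): final hidden state for each layer

# For a single-layer, unidirectional RNN the last step of `result` equals h_s.
print(torch.allclose(result[-1], h_s[0]))  # True
```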

Summary

Recurrent Neural Networks are useful whenever the previous states of the data play a significant role in predicting the next word or state. PyTorch provides an nn.RNN module that lets us design such models conveniently. However, during training the gradients may vanish or explode. This can be addressed with several methods, such as gradient clipping, truncated backpropagation through time and LSTMs.


DataFlair Team
