How to Create Subplots in Matplotlib


There are various ways to create subplots in Matplotlib, which causes a lot of confusion among users. In this post, we will look at them one by one and try to understand what they do and how to use them efficiently.

1. Stacking subplots in one direction –

When we stack subplots in one direction with plt.subplots(), the returned axes object is a 1D NumPy array containing the created Axes. To create the subplots themselves, we can use either the MATLAB-style interface or the object-oriented interface.
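A quick way to see this is to inspect what plt.subplots() actually returns. This minimal sketch needs no data; it switches to the non-interactive Agg backend (an assumption so it runs headless; drop that line in a notebook).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

# Two subplots stacked vertically: 2 rows, 1 column
fig, axes = plt.subplots(2, 1)
print(type(axes))   # <class 'numpy.ndarray'>
print(axes.shape)   # (2,)

# Every element of the array is an Axes object
print(all(isinstance(a, Axes) for a in axes))  # True

# With a single subplot there is no array: a bare Axes is returned
fig_single, single = plt.subplots()
print(isinstance(single, np.ndarray))  # False
```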

Matlab Style Interface –

Matplotlib was inspired by MATLAB; it was originally created as a Python alternative for MATLAB users. As a result, much of the syntax and many of the functions in pyplot resemble MATLAB's.

Let’s read a dataset to work with.

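import pandas as pd
import matplotlib.pyplot as plt

# Jupyter/IPython only: render plots inside the notebook (remove in a plain .py script)
%matplotlib inline

# ICICI Bank (ICICIBANK.NS) stock prices; parse the Date column as datetimes
url = "https://raw.githubusercontent.com/bprasad26/lwd/master/data/ICICIBANK.NS.csv"
df = pd.read_csv(url, parse_dates=['Date'])
df.head()  # preview the first five rows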

Now suppose you want to create two line charts, one on top of the other. To do that, you can use subplots in Matplotlib.

# create a figure
plt.figure(figsize=(10, 8))

# create the first subplot: 2 rows, 1 column, position 1 (top)
plt.subplot(2, 1, 1)
plt.plot(df['Date'], df['High'])
plt.ylabel("High Price")

# create the second subplot: position 2 (bottom)
plt.subplot(2, 1, 2)
plt.plot(df['Date'], df['Low'])
plt.ylabel("Low Price")
plt.xlabel("Year")
plt.show()

Here, we first created a figure and then the first subplot. plt.subplot(2, 1, 1) divides the figure into a grid of 2 rows and 1 column and makes the 1st cell the current axes, so the following plt.plot() and plt.ylabel() calls draw on it. Next, we created the second subplot with plt.subplot(2, 1, 2) and called plt.show() to display the figure.
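The idea of a "current" axes is what makes the MATLAB-style interface work. This minimal sketch (no data, Agg backend assumed so it runs headless) checks which axes pyplot is targeting after each plt.subplot() call.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 8))

# plt.subplot(nrows, ncols, index): index starts at 1, counting left to right, top to bottom
top = plt.subplot(2, 1, 1)
print(plt.gca() is top)      # True: the new subplot becomes the current axes

bottom = plt.subplot(2, 1, 2)
print(plt.gca() is bottom)   # True: the current axes moved to the bottom subplot

# pyplot calls act on the current axes, so this label lands on the bottom subplot
plt.ylabel("Low Price")
print(repr(top.get_ylabel()), repr(bottom.get_ylabel()))  # '' 'Low Price'
```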

Object Oriented Interface –

If you want to create more complicated plots, you should choose the object-oriented interface: you keep an explicit reference to every Axes instead of relying on pyplot's idea of the "current" one.

Let’s create the same plots using the object-oriented interface.

fig, ax = plt.subplots(2, figsize=(10, 8))

ax[0].plot(df['Date'], df['High'])
ax[0].set_ylabel("High Price")

ax[1].plot(df['Date'], df['Low'])
ax[1].set_xlabel("Year")
ax[1].set_ylabel("Low Price")
plt.show()

Here, plt.subplots(2) created the figure and both axes in a single call, and we then drew on each one through ax[0] and ax[1]. If you look closely, you can see that instead of plt.xlabel() and plt.ylabel(), we are using the Axes methods ax.set_xlabel() and ax.set_ylabel().
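To see why the explicit Axes methods are safer, here is a minimal sketch (no data, Agg backend assumed so it runs headless) that mixes both styles. After plt.subplots(), pyplot's current axes is the last one created, so a bare plt.ylabel() only ever reaches the bottom subplot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, figsize=(10, 8))

# pyplot functions act on the current axes: the last one created, ax[1]
print(plt.gca() is ax[1])  # True
plt.ylabel("Which one?")   # lands on the bottom subplot

# Axes methods name their target explicitly, so there is no ambiguity
ax[0].set_ylabel("High Price")
ax[1].set_ylabel("Low Price")  # overwrites the pyplot label above

print(ax[0].get_ylabel(), "|", ax[1].get_ylabel())  # High Price | Low Price
```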

If you are creating only a few axes, you can also unpack them directly into separate variables, like this:

fig, (ax1, ax2) = plt.subplots(2, figsize=(10, 8))

ax1.plot(df['Date'], df['High'])
ax1.set_ylabel("High Price")

ax2.plot(df['Date'], df['Low'])
ax2.set_xlabel("Year")
ax2.set_ylabel("Low Price")
plt.show()
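When the subplots differ only in which column they show, you can also loop over the 1D array instead of naming each axes. This is a minimal sketch; the DataFrame below is a hypothetical stand-in with made-up prices (not the ICICI data), and the Agg backend is assumed so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the stock data: made-up prices, for illustration only
df = pd.DataFrame({
    "Date": pd.date_range("2020-01-01", periods=5, freq="D"),
    "High": [10.0, 10.5, 10.2, 10.8, 11.0],
    "Low": [9.5, 9.9, 9.8, 10.1, 10.4],
})

fig, axes = plt.subplots(2, figsize=(10, 8))

# Pair each axes with the column it should show
for ax, column in zip(axes, ["High", "Low"]):
    ax.plot(df["Date"], df[column])
    ax.set_ylabel(f"{column} Price")
axes[-1].set_xlabel("Date")

# Unpacking needs exactly as many names as there are axes
try:
    fig2, (only_one,) = plt.subplots(2)
except ValueError as err:
    print("ValueError:", err)
```

Unpacking reads nicely for two or three axes; the loop scales better once the grid grows.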

Now, let’s say that instead of stacking the subplots on top of each other, you want to place them side by side.

To create side-by-side subplots, we pass (1, 2) to plt.subplots() for one row and two columns.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,8))

ax1.plot(df['Date'], df['High'])
ax1.set_ylabel("Price")
ax1.set_xlabel("Year")
ax1.set_title("High price")

ax2.plot(df['Date'], df['Low'])
ax2.set_xlabel("Year")
ax2.set_title("Low price")
plt.savefig("subplot4.png")
plt.show()
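In the side-by-side version only ax1 carries the "Price" label, so it can help to make both panels use the same vertical scale. This minimal sketch uses hypothetical values (not the stock data) and the Agg backend so it runs headless; sharey=True is a standard plt.subplots() parameter.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# One row, two columns; sharey=True links the y-axes of both panels
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8), sharey=True)

# Hypothetical values, for illustration only
ax1.plot([0, 1, 2], [10, 12, 11])
ax2.plot([0, 1, 2], [8, 9, 30])
ax1.set_ylabel("Price")

# Shared y-limits: both panels autoscale to cover the data of both
print(ax1.get_ylim() == ax2.get_ylim())  # True

# tight_layout() adjusts spacing so titles and labels do not overlap
fig.tight_layout()
```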

2. Stacking subplots in two directions –

When stacking subplots in two directions, the returned axes object is a 2D NumPy array, so each subplot is accessed by its [row, column] index.
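A minimal sketch of how that 2D array is laid out (no data, Agg backend assumed so it runs headless). It also shows squeeze=False, a plt.subplots() option that always returns a 2D array, which is handy when the grid size is not known in advance.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2)
print(ax.shape)  # (2, 2)

# ax[row, col]: row 0 is the top row, column 0 the left column
top_right = ax[0, 1]

# ax.flat walks the grid row by row: [0, 0], [0, 1], [1, 0], [1, 1]
print(list(ax.flat)[1] is top_right)  # True

# squeeze=False keeps a 2D array even for a single row
fig_row, ax_row = plt.subplots(1, 3, squeeze=False)
print(ax_row.shape)  # (1, 3)
```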

Let’s say we want to plot the high, low, open, and close prices together using subplots. To do that, we can write:

fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(12, 12))

# Low price
ax[0, 0].plot(df['Date'], df['Low'])
ax[0, 0].set_title("Low Price")
ax[0, 0].set_ylabel("Price")

# High Price
ax[0, 1].plot(df['Date'], df['High'])
ax[0, 1].set_title("High Price")

# Open Price
ax[1, 0].plot(df['Date'], df['Open'])
ax[1, 0].set_title("Open Price")
ax[1, 0].set_xlabel("Year")
ax[1, 0].set_ylabel("Price")

# Close price
ax[1, 1].plot(df['Date'], df['Close'])
ax[1, 1].set_title("Close Price")
ax[1, 1].set_xlabel("Year")
plt.show()
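The four blocks above repeat the same pattern, so the grid can also be filled in a loop over ax.flat. This is a minimal sketch; the DataFrame is a hypothetical stand-in with made-up prices (not the ICICI data), and the Agg backend is assumed so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the stock data: made-up prices, for illustration only
df = pd.DataFrame({
    "Date": pd.date_range("2020-01-01", periods=4, freq="D"),
    "Low": [9.5, 9.9, 9.8, 10.1],
    "High": [10.0, 10.5, 10.2, 10.8],
    "Open": [9.8, 10.1, 10.0, 10.3],
    "Close": [9.9, 10.3, 10.1, 10.6],
})

fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(12, 12))

# Same layout as the hand-written version: ax.flat visits [0,0], [0,1], [1,0], [1,1]
for axis, column in zip(ax.flat, ["Low", "High", "Open", "Close"]):
    axis.plot(df["Date"], df[column])
    axis.set_title(f"{column} Price")

# Axis labels only on the outer edges: left column and bottom row
for axis in ax[:, 0]:
    axis.set_ylabel("Price")
for axis in ax[1, :]:
    axis.set_xlabel("Year")
```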

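To reduce clutter, I also passed sharex=True and sharey=True. All four subplots then share the same x and y limits, and Matplotlib draws x tick labels only on the bottom row and y tick labels only on the left column.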
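One way to see what sharing does is to change the limits on one subplot and read them back from another. A minimal sketch, with no data and the Agg backend assumed so it runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)

# Shared axes are linked: setting limits on one subplot updates all of them
ax[0, 0].set_xlim(0, 5)
ax[0, 0].set_ylim(-1, 1)

print(ax[1, 1].get_xlim() == (0, 5))   # True
print(ax[1, 1].get_ylim() == (-1, 1))  # True
```

Because the limits are linked, zooming or panning one subplot in an interactive window moves the others too, which keeps the four price series directly comparable.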

