I have implemented a Python class that uses Matplotlib to show a grid of subplots, with navigation buttons to move between different subsets of the data. Each subplot displays a contour plot with its own colorbar.
It mostly works, but I would like advice on making the code faster and more efficient — in particular, on how the subplots are updated and cleared when navigating between subsets of the data. I am open to suggestions: any insights or alternative approaches that achieve the same functionality would be greatly appreciated.
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
import numpy as np
class MultipleFigures:
    """Page through a stack of 2-D arrays, one filled-contour subplot per item.

    The figure holds an ``nrows`` x ``ncols`` grid of axes. Each page shows the
    next ``nrows * ncols`` items of *data*, each with its own colorbar.
    "Next" / "Previous" buttons at the bottom of the figure move one page
    forward or back. Subplots left empty on a short final page are hidden.

    Keep a reference to the instance while the window is open: Matplotlib's
    widget documentation requires callers to hold on to widgets for them to
    stay responsive.

    Attributes:
        data: The full sequence being paged through.
        per_page: Number of subplots (items) shown per page.
        indices: ``[start, stop]`` slice bounds of the page on screen.
        temp: ``data[start:stop]``, i.e. the items currently plotted.
        fig: The Matplotlib figure.
        ax: Flat array of the grid's axes, in row-major order.
        cb: Colorbars of the plotted items; ``cb[i]`` belongs to ``ax[i]``.
        button_next, button_prev: The navigation buttons.
    """

    def __init__(self, data, nrows=3, ncols=3, figsize=(15, 7)):
        """Create the figure, draw the first page and wire up the buttons.

        Args:
            data: Sliceable, sized sequence of items (for example an
                ``(n, rows, cols)`` NumPy array). Each item is passed to
                ``Axes.contourf`` as its ``Z`` argument.
            nrows, ncols: Grid shape; one page shows ``nrows * ncols`` items.
            figsize: Figure size in inches, forwarded to ``plt.subplots``.
        """
        self.data = data
        self.per_page = nrows * ncols
        self.indices = [0, self.per_page]
        # squeeze=False always yields a 2-D array, so ravel() also works for 1x1.
        self.fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
        self.ax = axes.ravel()
        # Address this figure explicitly rather than pyplot's "current" figure,
        # which may be another window by the time a button callback runs.
        self.fig.subplots_adjust(wspace=0.5, hspace=0.5)
        self.cb = []
        self.plot_initial_data()

        nextax = self.fig.add_axes([0.8, 0.02, 0.1, 0.04])
        prevax = self.fig.add_axes([0.1, 0.02, 0.1, 0.04])
        self.button_next = Button(nextax, 'Next')
        self.button_prev = Button(prevax, 'Previous')
        self.button_next.on_clicked(self.next)
        self.button_prev.on_clicked(self.previous)

    def plot_initial_data(self):
        """Draw the page that ``self.indices`` points at (the first page after __init__)."""
        self._show_page(self.indices[0])

    def next(self, event):
        """Button callback: show the following page, if there is one."""
        start = self._target_start(+1)
        if start is not None:
            self._show_page(start)

    def previous(self, event):
        """Button callback: show the preceding page, if there is one."""
        start = self._target_start(-1)
        if start is not None:
            self._show_page(start)

    def _target_start(self, step):
        """Return the start index ``step`` pages away, or None if no data starts there."""
        start = self.indices[0] + step * self.per_page
        if 0 <= start < len(self.data):
            return start
        return None

    def _show_page(self, start):
        """Replace whatever is on screen with the page beginning at index ``start``."""
        self._clear_previous_axes()
        self.indices = [start, start + self.per_page]
        self.temp = self.data[slice(*self.indices)]
        self.update_plots(self.temp)
        # A short final page leaves trailing subplots empty; hide them.
        self.clear_unused_plots(len(self.temp))

    def update_plots(self, temp_df):
        """Plot each item of ``temp_df`` into the grid, in order, and request a redraw."""
        for i, item in enumerate(temp_df):
            self.ax[i].set_visible(True)  # may have been hidden on a short final page
            _, colorbar = self.contourplot(item, ax=self.ax[i])
            self.cb.append(colorbar)
        # draw_idle only schedules a repaint, so the whole page is drawn once.
        self.fig.canvas.draw_idle()

    def clear_unused_plots(self, num_plots_to_clear):
        """Clear and hide every subplot from index ``num_plots_to_clear`` onward.

        Despite its name, the argument is the number of leading subplots to *keep*.
        """
        for ax in self.ax[num_plots_to_clear:]:
            ax.clear()
            ax.set_visible(False)

    def _clear_previous_axes(self):
        """Remove the colorbars and contour plots of the page currently on screen."""
        # Remove each colorbar *before* clearing its axes. Colorbar.remove()
        # hands the borrowed space back to the parent axes, which it locates
        # through the contour set, and ax.clear() detaches that contour set.
        for colorbar, ax in zip(self.cb, self.ax):
            colorbar.remove()
            ax.clear()
        self.cb.clear()

    @staticmethod
    def contourplot(data, ax):
        """Draw a filled contour of ``data`` on ``ax`` and attach a colorbar to it.

        Returns:
            ``(ax, colorbar)``.
        """
        contour = ax.contourf(data)
        # Use the axes' own figure instead of plt.colorbar, which goes through
        # pyplot's current-figure state.
        cb = ax.figure.colorbar(contour, ax=ax)
        return ax, cb
if __name__ == "__main__":
    # `%matplotlib qt` is IPython syntax and a SyntaxError in plain Python.
    # Issue the same magic through the IPython API when running under IPython
    # or Jupyter; otherwise keep Matplotlib's default backend.
    try:
        shell = get_ipython()  # noqa: F821 -- name is provided by IPython
    except NameError:
        shell = None
    if shell is not None:
        shell.run_line_magic("matplotlib", "qt")

    data = np.random.rand(20, 10, 10)  # 20 random 10x10 fields -> 3 pages of 9
    # Keep this reference alive: the navigation buttons stop responding if
    # the widget objects are garbage-collected.
    fl_ = MultipleFigures(data)
    plt.show()