89

In Python and Matplotlib, it is easy to either display the plot as a popup window or save the plot as a PNG file. How can I instead save the plot to a numpy array in RGB format?

1
  • None of the answers below offers the full solution: first you need to prevent pyplot from displaying the figure while plotting, which is the step before the figure can be acquired. Commented Jan 16, 2024 at 19:07

10 Answers

112

This is a handy trick for unit tests and the like, when you need to do a pixel-to-pixel comparison with a saved plot.

One way is to use fig.canvas.tostring_rgb and then numpy.frombuffer with the appropriate dtype. There are other ways as well, but this is the one I tend to use.

E.g.

import matplotlib.pyplot as plt
import numpy as np

# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)

# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()

# Now we can save it to a numpy array.
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
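
A note for recent Matplotlib versions: as far as I know, tostring_rgb was deprecated in Matplotlib 3.8 and later removed, so the snippet above may fail on a current install. Here is a minimal sketch of the same idea using buffer_rgba instead. It returns RGBA, so the alpha channel is dropped by slicing. The Agg backend and the toy data are assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window pops up
import matplotlib.pyplot as plt
import numpy as np

# Toy plot, only here so there is something to render.
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 0])

# The figure has to be drawn before the pixel buffer is valid.
fig.canvas.draw()

# buffer_rgba() exposes the canvas as an (H, W, 4) uint8 buffer.
rgba = np.asarray(fig.canvas.buffer_rgba())

# Drop alpha and copy, so the array no longer aliases the canvas memory.
rgb = rgba[..., :3].copy()
```

The copy matters if the figure is redrawn later, because the array returned by np.asarray is a view of the canvas buffer.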

14 Comments

Works on Agg, add matplotlib.use('agg') before import matplotlib.pyplot as plt to use it.
With images, the canvas adds a big margin, so I found it useful to insert fig.tight_layout(pad=0) before drawing.
For figures with lines and text, it can also be important to turn antialiasing off. For lines plt.setp([ax.get_xticklines() + ax.get_yticklines() + ax.get_xgridlines() + ax.get_ygridlines()],antialiased=False) and for text mpl.rcParams['text.antialiased']=False
@JoeKington np.fromstring with sep='' is deprecated since version 1.14. It should be replaced with data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) in future versions
In case you run into 'FigureCanvasGTKAgg' object has no attribute 'renderer', remember to matplotlib.use('Agg'): stackoverflow.com/a/35407794/5339857
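
For the unit-test use case mentioned in the answer, a pixel-to-pixel comparison might look roughly like the sketch below. The helper name render_rgb and the sample data are hypothetical, and it assumes the Agg backend and a Matplotlib version that provides buffer_rgba. Antialiasing is switched off for the line, as suggested in the comments above.

```python
import matplotlib
matplotlib.use("Agg")  # select the backend before importing pyplot
import matplotlib.pyplot as plt
import numpy as np


def render_rgb(ys):
    """Draw ys on a fresh figure and return it as an (H, W, 3) uint8 array."""
    fig, ax = plt.subplots()
    ax.plot(ys, antialiased=False)  # crisp edges for the plotted line
    fig.canvas.draw()
    rgb = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)  # free the figure so repeated calls do not pile up
    return rgb


a = render_rgb([0, 1, 2])
b = render_rgb([0, 1, 2])
c = render_rgb([2, 1, 0])

same = np.array_equal(a, b)           # identical plots give identical pixels
different = not np.array_equal(a, c)  # a different plot gives different pixels
```

In a real test, one of the two arrays would be loaded from a stored reference image rather than rendered again.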
56

There is a slightly simpler option than @JUN_NETWORKS's answer. Instead of saving the figure as PNG, one can use another format, such as raw or rgba, and skip the cv2 decoding step.

In other words the actual plot-to-numpy conversion boils down to:

import io
import numpy as np

io_buf = io.BytesIO()
# DPI must equal the figure's own dpi, otherwise the reshape below fails
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.frombuffer(io_buf.getvalue(), dtype=np.uint8).reshape(
    int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)
io_buf.close()

Hope this helps.

8 Comments

I think this answer is far superior to the ones above: 1) It produces high-res images and 2) doesn't rely on external packages like cv2.
I get a reshape error "cannot reshape array of size 3981312 into shape (480,640,newaxis)". Any ideas?
Indeed this answer is exactly what I was looking for ! Thank you !
@FabianHertwig I have the same problem, and here is the fix. When you create the fig you have to set dpi to be the same when you saved it fig = plt.figure(figsize=(16, 4), dpi=128) then fig.savefig(io_buf, format='raw', dpi=128)
This worked for me when omitting the dpi parameter, i.e. fig.savefig(io_buf, format='raw')
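
To spell out the reshape error discussed in these comments: the raw buffer holds height x width x 4 bytes at the dpi used for saving, while fig.bbox reflects the figure's own dpi. The size 3981312 equals 864 x 1152 x 4, which is consistent with saving a default 6.4 x 4.8 inch figure at dpi=180. A minimal sketch that avoids the mismatch by saving at the figure's dpi follows. The figure size, the dpi value and the data are assumptions chosen for illustration.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

DPI = 128  # hypothetical value; set it once, when the figure is created
fig, ax = plt.subplots(figsize=(4, 3), dpi=DPI)
ax.plot([0, 1, 2], [1, 0, 1])

with io.BytesIO() as buf:
    # dpi="figure" reuses the figure's own dpi, so the buffer size
    # matches the canvas size queried below.
    fig.savefig(buf, format="raw", dpi="figure")
    raw = np.frombuffer(buf.getvalue(), dtype=np.uint8)

w, h = fig.canvas.get_width_height()
img = raw.reshape(h, w, -1)  # the last axis comes out as 4 (RGBA)
```

If a different output resolution is needed, changing the dpi when the figure is created keeps the two sizes in agreement.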
23

Some people propose a method like this:

np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')

Of course, this code works, but the resulting numpy array image has a low resolution.

The code I propose is this:

import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)

x = np.linspace(-np.pi, np.pi)

ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")

ax.plot(x, np.sin(x), label="sin")

ax.legend()
ax.set_title("sin(x)")


# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img

# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)

This code works well.
You can get a high-resolution image as a numpy array if you set a large number on the dpi argument.

3 Comments

I suggest adding the import statements along with the function.
@AnshulRai Thanks for your great advice!! I've added code about import, plot, and how to use the function.
Hey, I was wondering if there's a way I could get rid of x and y axis labels and just get the plot?
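
Regarding the last question, one possible way to drop the axes and labels and keep only the plotted curve is to hide the axes and stretch them over the whole figure. This is a sketch under the assumption that the Agg backend and buffer_rgba are available. It is not part of the answer's function.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(-np.pi, np.pi)
ax.plot(x, np.sin(x))

# Hide ticks, labels and the frame, then let the axes fill the figure.
ax.set_axis_off()
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

fig.canvas.draw()
img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()  # (H, W, 3) RGB
```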
23

Time to benchmark your solutions.

import io
import matplotlib
matplotlib.use('agg')  # turn off interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(range(10))


def plot1():
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))


def plot2():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='png')
        buff.seek(0)
        im = plt.imread(buff)


def plot3():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='raw')
        buff.seek(0)
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))

>>> %timeit plot1()
34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot2()
50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot3()
16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Under this scenario, IO raw buffers are the fastest to convert a matplotlib figure to a numpy array.

Additional remarks:

  • if you don't have access to the figure, you can always extract it from the axes:

    fig = ax.figure

  • if you need the array in the channel x height x width format, do

    im = im.transpose((2, 0, 1))
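
A small illustration of the last remark, using a zero-filled stand-in for the image so that it runs without Matplotlib. The 480 x 640 RGBA shape is an assumption that matches what the raw buffer gives for a default-sized figure.

```python
import numpy as np

# Stand-in for an (H, W, 4) RGBA image coming from the canvas buffer.
im = np.zeros((480, 640, 4), dtype=np.uint8)

# Keep only the RGB channels, then move channels to the front.
rgb = im[..., :3]
chw = rgb.transpose((2, 0, 1))

print(chw.shape)  # (3, 480, 640)
```

transpose returns a view, so call np.ascontiguousarray(chw) if the consumer needs contiguous memory.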

4 Comments

Why do so many stackoverflowers obsess about benchmarks?
@Epimetheus, isn't it self-explanatory? the world only cares about speed and results' quality, a fast method can save companies.
Is this not missing a fig.canvas.draw() in the other two functions?
@darda no, it's not needed in plot2 and plot3. Do you see an empty picture?
7

In case somebody wants a plug and play solution, without modifying any prior code (getting the reference to pyplot figure and all), the below worked for me. Just add this after all pyplot statements i.e. just before pyplot.show()

import numpy
from matplotlib import pyplot

canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))


7

MoviePy makes converting a figure to a numpy array quite simple. It has a built-in function for this called mplfig_to_npimage(). You can use it like this:

from moviepy.video.io.bindings import mplfig_to_npimage
import matplotlib.pyplot as plt

fig = plt.figure()  # make a figure
numpy_fig = mplfig_to_npimage(fig)  # convert it to a numpy array

3 Comments

As the shortest working solution Daniel Giger should get more upvotes.
But it requires installing another package :(
Src is short: github.com/Zulko/moviepy/blob/master/moviepy/video/io/… and similar to the answer by @joe-kington.
4

As Joe Kington has pointed out, one way is to draw on the canvas, convert the canvas to a byte string and then reshape it into the correct shape.

import matplotlib.pyplot as plt
import numpy as np
import math

plt.switch_backend('Agg')


def canvas2rgb_array(canvas):
    """Adapted from: https://stackoverflow.com/a/21940031/959926"""
    canvas.draw()
    buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    ncols, nrows = canvas.get_width_height()
    scale = round(math.sqrt(buf.size / 3 / nrows / ncols))
    return buf.reshape(scale * nrows, scale * ncols, 3)


# Make a simple plot to test with
t = np.arange(0.0, 2.0, 0.01)
s = 1 + np.sin(2 * np.pi * t)
fig, ax = plt.subplots()
ax.plot(t, s)

# Extract the plot as an array
plt_array = canvas2rgb_array(fig.canvas)
print(plt_array.shape)

However, as canvas.get_width_height() returns width and height in display coordinates, there are sometimes scaling issues that are resolved in this answer.
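
To see why the scale factor in canvas2rgb_array works: if the real buffer is s times larger than the reported size along each axis, it holds 3 x (s x nrows) x (s x ncols) bytes, so s is the square root of buf.size / (3 x nrows x ncols). The sketch below checks that arithmetic with a synthetic buffer. The 2x factor (as on a high-DPI display) and the 640 x 480 size are assumptions for illustration.

```python
import math

import numpy as np

ncols, nrows = 640, 480  # size as reported by canvas.get_width_height()
true_scale = 2           # hypothetical high-DPI factor

# Synthetic RGB buffer that is true_scale times larger along each axis.
buf = np.zeros(3 * (true_scale * nrows) * (true_scale * ncols), dtype=np.uint8)

# Same formula as in canvas2rgb_array above.
scale = round(math.sqrt(buf.size / 3 / nrows / ncols))
img = buf.reshape(scale * nrows, scale * ncols, 3)

print(scale, img.shape)  # 2 (960, 1280, 3)
```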


2

Cleaned up version of the answer by Jonan Gueorguiev:

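import io
import numpy as np

# fig and dpi come from the surrounding code; dpi should match fig.dpi
with io.BytesIO() as io_buf:
  fig.savefig(io_buf, format='raw', dpi=dpi)
  image = np.frombuffer(io_buf.getvalue(), np.uint8).reshape(
      int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)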


0

import numpy as np
import cv2
import time
import justpyplot as jplt

t0 = time.perf_counter()  # reference time for the x values
xs, ys = [], []
while(cv2.waitKey(1) != 27):
    xt = time.perf_counter() - t0
    yx = np.sin(xt)
    xs.append(xt)
    ys.append(yx)
    
    frame = np.full((500,470,3), (255,255,255), dtype=np.uint8)
    
    vals = np.array(ys)

    plotted_in_array = jplt.just_plot(frame, vals,title="sin() from Clock")
    
    cv2.imshow('np array plot', plotted_in_array)

The issue with all the matplotlib approaches is that matplotlib can still render and display the plot even if you call plt.ioff() or return the figure. Even if you do succeed, it behaves differently on different platforms (because matplotlib delegates rendering to a backend that depends on the OS), and you take a performance hit to get the plot as a numpy array. I measured all the previously suggested matplotlib approaches and they take milliseconds: most often dozens, sometimes even more.

I couldn't find a simple library that just does this, so I had to write it myself. It plots to a numpy array in fully vectorized numpy (not a single loop) for all the parts, such as scatter, connected, axis and grid, including the size of the points and the line thickness, and it does so in microseconds:

https://github.com/bedbad/justpyplot


0

This solution does not require any libraries other than matplotlib, numpy and PIL:

CanvasAgg Demo

Using this, one can also have matplotlib plot directly on tkinter widgets (using ImageTk).

from PIL import Image, ImageTk

import numpy as np

import tkinter as tk


from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure

fig = Figure(figsize=(5, 4))
canvas = FigureCanvasAgg(fig)

ax = fig.add_subplot()
line = ax.plot([1, 2, 3])

canvas.draw()
rgba = np.asarray(canvas.buffer_rgba())
im = Image.fromarray(rgba)


root = tk.Tk()
tim = ImageTk.PhotoImage(im)
im_lb = tk.Label(image = tim)

im_lb.pack()

tk.mainloop()

This unlocks the potential of having interactive matplotlib plots via tkinter!

