This function works exactly as I want; it's just taking too long.

For speed-ups, I've tried to do as much as I can before the main for loop by binding each function to a local variable. I've also switched from pandas DataFrames to NumPy arrays and decreased the output dpi.

This function is being fed large amounts of data, so any speed improvement suggestions will be much appreciated. I don't know any Cython (or C) but would be willing to learn some if it would dramatically improve performance. I also welcome any suggestions on how I could improve the style of my code.

import os
import logging
import traceback
import warnings
from itertools import chain

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

NO_GROUPING_NAME = 'NoGrouping'
plt.style.use('ggplot')


def cdf_plot(total_param_values):
    """
    Given a 3-deep nested dictionary, cdf_plot saves a cumulative frequency 
    distribution plot out of the values of each inner-most dictionary. This will
    be a scatter plot with colours corresponding to the keys of the dict being plotted. 

    If a 2-deep nested dictionary has key == NO_GROUPING_NAME then the corresponding
    value will be a dictionary of only one key with value of one list of floats, so 
    the plot will only have one colour. In this case, no legend is drawn.

    The cumulative frequency distribution data is formed from a list of values
    (call the list all_x_values) by plotting the sorted values on the x-axis. 
    A corresponding y-value (for a given x-value) is equal to norm.ppf(i/len(all_x_values))
    where i is the index of the given x-value in all_x_values and norm.ppf is a
    function from scipy.stats (Percent point function (inverse of cdf — percentiles)).

    Parameters
    ----------
    total_param_values : { string : { string : { string : list of floats}}}
        This corresponds to  {p_id : {grouping : {group_instance : values}}}
    """

    # Do as much as possible before loop
    fig = plt.figure()
    add_subplot = fig.add_subplot
    textremove = fig.texts.remove
    xlabel = plt.xlabel
    ylabel = plt.ylabel
    yticks = plt.yticks
    cla = plt.cla
    savefig = plt.savefig
    figtext = plt.figtext
    currentfigtext = None
    colours = ('b', 'g', 'r', 'c','teal', 'm','papayawhip', 'y', 'k', 
               'aliceblue', 'aqua', 'forestgreen', 'deeppink', 'blanchedalmond',
               'burlywood', 'darkgoldenrod') 

    nparray = np.array
    nanstd = np.nanstd
    nanmean = np.nanmean
    npsort = np.sort
    isnan = np.isnan
    vectorize = np.vectorize

    normppf = norm.ppf
    chainfrom_iterable = chain.from_iterable

    # Prepare yticks
    y_labels = [0.0001, 0.001, 0.01, 0.10, 0.25, 0.5,
            0.75, 0.90, 0.99, 0.999, 0.9999]
    y_pos = [normppf(i) for i in y_labels]

    try:
        # Hide annoying warning
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=FutureWarning)

            for p_id, p_id_dict in total_param_values.items():
                for grouping, grouping_dict in p_id_dict.items():

                    #check whether plot already exists
                    save_name = p_id + grouping + '.png'
                    if os.path.exists(save_name):
                        continue

                    # Keep count of position in colour cycle
                    colour_count = 0

                    ax = add_subplot(111)
                    axscatter = ax.scatter

                    # Work out normalising function
                    chn = chainfrom_iterable(grouping_dict.values())
                    flattened = list(chn)
                    std = nanstd(flattened)
                    mean = nanmean(flattened)
                    if std:
                        two_ops = lambda x: (x - mean) / std
                        v_norm = vectorize(two_ops)
                    else:       
                        one_op = lambda x: x - mean
                        v_norm = vectorize(one_op)

                    # Keep track of total number of values plotted this iteration
                    total_length = 0    

                    for group_instance, values in grouping_dict.items():
                        values = nparray(values)
                        values = npsort(values[~isnan(values)])
                        length = len(values)                        
                        total_length += length

                        # Skip graphing any empty array
                        if not length:
                            continue

                        # Normalise values to be ready for plotting on x-axis
                        values = v_norm(values)

                        # Prepare y-values as described in function doc
                        y = [normppf(i/length) for i in range(length)]

                        axscatter(values, y, color=colours[colour_count % len(colours)],
                                label=group_instance + '   (' + str(length) + ')',
                                 alpha=0.6)
                        colour_count += 1

                    # If no values were found, clear axis and skip to next iteration
                    if not total_length:
                        cla()
                        continue

                    if grouping != NO_GROUPING_NAME:
                        try:
                            ax.legend(loc='lower right', title=grouping + '  ('
                                        + str(total_length) + ')', fontsize=6,
                                         scatterpoints=1)
                        except ValueError:
                            print('EXCEPTION: Weird error with legend() plotting,\
                                    something about taking the max of an empty sequence')
                            pass
                    else:
                        # Turn off legend but display total_length in bottom right corner
                        ax.legend().set_visible(False)
                        if currentfigtext is not None:
                            textremove(currentfigtext)
                        currentfigtext = figtext(0.99, 0.01, 'Number of points = ' 
                                                + str(total_length),
                                                 horizontalalignment='right')
                    xlabel('')
                    ylabel('')
                    yticks(y_pos, y_labels)
                    savefig(save_name, dpi=60)
                    cla()

    except Exception as e:
        print('It broke.............', e)
        print('Variable dump........')
        print('grouping {}, group_instance {}, values {}, length {}, ax {}, y {},\
                colour_count {}, figtext {} std {}, mean {},\
                save_name {}'.format(grouping, group_instance, values, length,
                 ax, y, colour_count, currentfigtext, std, mean, save_name))
        logging.error(traceback.format_exc())
        raise

    # Make sure no figures remain
    plt.close('all')


# In an attempt to make the nesting clear I've written test_values out in a 
# weird way (weird to me at least) 
test_values = {
                'p_1' : { 
                        'NoGrouping' : {
                                    '' : list(np.random.rand(100))
                                       },
                         'Sky' : { 
                                    'Blue' : list(np.random.rand(100)),
                                    'Red' : list(np.random.rand(100)),
                                    'Grey' : list(np.random.rand(100))
                                 } 
                       },
                'p_2' : { 
                        'NoGrouping' : {
                                    '' : list(np.random.rand(100))
                                       },
                        'Sky' : { 
                                    'Blue' : list(np.random.rand(100)),
                                    'Red' : list(np.random.rand(100)),
                                    'Grey' : list(np.random.rand(100))
                                } 
                       }
} 

cdf_plot(test_values)
Have you run line profiling on your code? – Curt F. Jan 13 at 16:19

Your function is too large to digest in one sitting without running it on my machine, so I'll start with a few quick observations.

The idea of putting a plot function (axscatter) in an innermost loop just feels wrong. It makes more sense to build arrays of values and call the plot just once per subplot. But if the number of values that you plot per call is large, and the number of iterations is small, it isn't such a big deal. It's hard to know just how many iterations this code performs.
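
A rough sketch of what I mean (plot_grouping is a hypothetical helper, not a drop-in replacement; v_norm, normppf and colours are the names you already have):

def plot_grouping(ax, grouping_dict, v_norm, normppf, colours):
    xs, ys, cs = [], [], []
    for i, (group_instance, values) in enumerate(grouping_dict.items()):
        values = np.asarray(values, dtype=float)
        values = np.sort(values[~np.isnan(values)])
        if not len(values):
            continue
        xs.append(v_norm(values))
        ys.append(normppf(np.arange(len(values)) / len(values)))
        # remember which colour each point belongs to
        cs.extend([colours[i % len(colours)]] * len(values))
    if xs:
        # one scatter call for the whole subplot; colour picked per point
        ax.scatter(np.concatenate(xs), np.concatenate(ys), c=cs, alpha=0.6)

Note that the per-group legend entries disappear with a single call, so you would need proxy artists (or keep one call per group, as you have now) if those labels matter.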

The use of vectorize on functions like lambda x: (x - mean) / std is unnecessary. vectorize does not speed up code; it just makes it easier to apply scalar operations to arrays. x - mean is already a NumPy array operation, as is the division by std.
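
Since values is already an ndarray by the time you normalise it, plain broadcasting does the same job (a minimal sketch, assuming values is an ndarray):

# no np.vectorize and no lambdas needed
if std:
    values = (values - mean) / std
else:
    values = values - mean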

The use of the list comprehension y = [normppf(i/length) for i in range(length)] is, I suspect, unnecessary. But I don't know exactly what normppf does: does it accept an array, or just scalar values?
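
If it does accept an array, the comprehension collapses to a single call:

# norm.ppf applied to the whole array of quantiles at once
y = normppf(np.arange(length) / length)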

Organizing this as one big function makes it hard to test and optimize pieces. I'd put each level of iteration in a separate function: I should be able to test the data manipulation functions without plotting, and to plot without generating a new set of data.
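
Something along these lines, as a sketch only (the function names are placeholders, and mean/std would still be worked out once per grouping as you do now):

def flatten(grouping_dict):
    # every value in one grouping, used to work out the normalisation
    return np.fromiter(chain.from_iterable(grouping_dict.values()), dtype=float)

def cdf_points(values, mean, std):
    # data manipulation only -- testable without touching matplotlib
    values = np.asarray(values, dtype=float)
    values = np.sort(values[~np.isnan(values)])
    x = (values - mean) / std if std else values - mean
    y = norm.ppf(np.arange(len(x)) / len(x))
    return x, y

def plot_group(ax, x, y, colour, label):
    # plotting only -- easy to try out with canned data
    ax.scatter(x, y, color=colour, label=label, alpha=0.6)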

I like to see, at a glance, where the code is massaging the data and where it is plotting. And with plotting it is nice to separate the actions that actually plot from those that tweak the appearance. I also like to clearly identify when the code is calling imported code. You do that with the np functions, but not with the others.

Thanks, your suggestions are great and I will be using them all. normppf does indeed accept an array so the list comprehension is redundant. – Ben Carr Nov 30 '15 at 9:20
