This function works exactly as I want, but it takes too long to run.
To speed it up, I've tried to do as much as I can before the main for
loop by binding each function to a local variable. I've also switched from using pandas DataFrames to NumPy arrays and decreased the output DPI.
This function is fed large amounts of data, so any suggestions for improving its speed would be much appreciated. I don't know any Cython (or C), but I would be willing to learn some if it would dramatically improve performance. I also welcome any suggestions on how I could improve the style of my code.
import os
import logging
import traceback
import warnings
from itertools import chain
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# Grouping key that marks the "no grouping" case; cdf_plot shows no legend
# for a grouping with this name.
NO_GROUPING_NAME = 'NoGrouping'
# Switch Matplotlib to its 'ggplot' style sheet (a process-wide setting).
plt.style.use('ggplot')
def _sorted_non_nan_groups(grouping_dict):
    """Return ``[(group_instance, values)]`` with NaNs dropped and values sorted.

    Groups that contain no non-NaN value are left out of the result.
    """
    cleaned = []
    for group_instance, values in grouping_dict.items():
        arr = np.asarray(values, dtype=float)
        arr = np.sort(arr[~np.isnan(arr)])
        if arr.size:
            cleaned.append((group_instance, arr))
    return cleaned


def cdf_plot(total_param_values):
    """
    Save one cumulative frequency distribution plot per (p_id, grouping) pair.

    Given a 3-deep nested dictionary, every inner-most dictionary is drawn as
    a scatter plot and written to ``p_id + grouping + '.png'`` in the current
    working directory. A plot whose file already exists is skipped, as is a
    grouping that holds no non-NaN value at all.

    Each key of the inner-most dictionary gets its own colour and a legend
    entry ``'<group_instance> (<number of points>)'``. If the grouping key
    equals NO_GROUPING_NAME no legend is drawn; the total number of points is
    written in the bottom right corner of the figure instead.

    The plotted data is formed as follows:

    * x-values: the sorted, NaN-free values of a group, shifted by the mean of
      all values of the grouping and, when the standard deviation of all
      values is non-zero, divided by that standard deviation.
    * y-values: ``norm.ppf(i / n)`` where ``i`` is the index of the x-value in
      the sorted group and ``n`` the size of the group (``norm.ppf`` is the
      percent point function, i.e. the inverse of the normal cdf).

    Parameters
    ----------
    total_param_values : { string : { string : { string : list of floats}}}
        This corresponds to {p_id : {grouping : {group_instance : values}}}

    Raises
    ------
    Exception
        Any exception raised while plotting is logged (with traceback and the
        name of the plot being produced) and then re-raised unchanged.
    """
    colours = ('b', 'g', 'r', 'c', 'teal', 'm', 'papayawhip', 'y', 'k',
               'aliceblue', 'aqua', 'forestgreen', 'deeppink',
               'blanchedalmond', 'burlywood', 'darkgoldenrod')

    # Probability labels for the y-axis and their positions on the
    # norm.ppf scale; norm.ppf accepts the whole list in one call.
    y_labels = [0.0001, 0.001, 0.01, 0.10, 0.25, 0.5,
                0.75, 0.90, 0.99, 0.999, 0.9999]
    y_pos = norm.ppf(y_labels)

    # One figure and ONE axes are reused for every plot. Calling
    # add_subplot(111) per plot adds a fresh axes each time on newer
    # Matplotlib releases, which piles up artists and slows every save.
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Name of the plot currently being produced, for the error log only.
    current_plot = None
    try:
        # Kept from the original implementation: silence FutureWarnings
        # emitted by the libraries while plotting.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=FutureWarning)
            for p_id, p_id_dict in total_param_values.items():
                for grouping, grouping_dict in p_id_dict.items():
                    save_name = p_id + grouping + '.png'
                    current_plot = save_name
                    # Do not redo a plot that already exists on disk.
                    if os.path.exists(save_name):
                        continue

                    groups = _sorted_non_nan_groups(grouping_dict)
                    # Nothing to draw for this grouping.
                    if not groups:
                        continue

                    # Normalisation constants over all values of the grouping.
                    pooled = np.concatenate([arr for _, arr in groups])
                    total_length = pooled.size
                    mean = pooled.mean()
                    std = pooled.std()
                    # A zero spread means "centre only", as dividing would
                    # produce NaNs.
                    scale = std if std else 1.0

                    # Start from an empty axes (also drops any old legend).
                    ax.cla()
                    for index, (group_instance, arr) in enumerate(groups):
                        length = arr.size
                        # NOTE: i == 0 maps to norm.ppf(0) == -inf, so the
                        # smallest value of every group has a non-finite
                        # y-coordinate; this follows the documented formula.
                        y = norm.ppf(np.arange(length) / float(length))
                        ax.scatter((arr - mean) / scale, y,
                                   color=colours[index % len(colours)],
                                   label='{} ({})'.format(group_instance,
                                                          length),
                                   alpha=0.6)
                    ax.set_yticks(y_pos)
                    ax.set_yticklabels(y_labels)

                    note = None
                    if grouping != NO_GROUPING_NAME:
                        try:
                            ax.legend(loc='lower right',
                                      title='{} ({})'.format(grouping,
                                                             total_length),
                                      fontsize=6, scatterpoints=1)
                        except ValueError:
                            # Best effort, as before: a legend failure must
                            # not prevent the plot from being saved.
                            logging.warning(
                                'Could not draw legend for %s', save_name,
                                exc_info=True)
                    else:
                        # No legend; show the point count in the corner.
                        note = fig.text(0.99, 0.01,
                                        'Number of points = '
                                        + str(total_length),
                                        horizontalalignment='right')

                    fig.savefig(save_name, dpi=60)

                    # The note lives on the figure, not the axes, so it has
                    # to be removed explicitly or it would show up on the
                    # next (possibly grouped) plot.
                    if note is not None:
                        note.remove()
    except Exception:
        # Only reference names that are guaranteed to be bound here, so the
        # handler itself can never mask the original error.
        logging.exception('cdf_plot failed while producing %r', current_plot)
        raise
    finally:
        # Make sure no figures remain, also when an error occurred.
        plt.close('all')
# Sample input with the expected nesting
# {p_id: {grouping: {group_instance: values}}}: two parameter ids, each with
# an ungrouped series and a 'Sky' grouping of three series of 100 random
# floats.
test_values = {
    p_id: {
        grouping: {
            group_instance: list(np.random.rand(100))
            for group_instance in group_instances
        }
        for grouping, group_instances in (
            ('NoGrouping', ('',)),
            ('Sky', ('Blue', 'Red', 'Grey')),
        )
    }
    for p_id in ('p_1', 'p_2')
}
cdf_plot(test_values)