import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.figure import SubplotParams
import os
# Shared plotting defaults for the helpers in this module.
dpi = 300  # NOTE(review): not read by the functions visible here — confirm usage elsewhere
ax_pad = 10
label_fontdict = dict(size=7)
title_fontdict = dict(size=10)
# Colorbar inset bounds handed to ``Axes.inset_axes``:
# [x0, y0, width, height] relative to the parent axes (just right of it).
cbar_inset = [1.02, 0, 0.0125, 0.96]
cbar_titel_fontdict = dict(size=7)  # (sic) name kept for callers that import it
cbar_labels_fontdict = dict(size=7)
# Snapshot of the save directory taken once at import time; later changes to
# rcParams are not reflected in this name.
root_fig = plt.rcParams['savefig.directory']
def make_cbar(ax, title, norm, cmap, cmap_labels, im=None, prefix_labels=True):
    """Generate a colorbar in an inset axes next to the given axes.

    Parameters
    ----------
    ax: Axes
        axes for which to plot colorbar; the colorbar is drawn in an inset
        placed at the module-level ``cbar_inset`` bounds.
    title: str
        title of colorbar, drawn left-aligned above it
    norm: Normalize
        Normalisation instance; only used when ``im`` is None
    cmap: Colormap
        Colormap; only used when ``im`` is None
    cmap_labels: dict
        mapping of tick position -> label text. If empty (or None), the
        default ticks are kept and only their label size is adjusted.
    im: ScalarMappable, optional
        existing mappable (e.g. the return value of ``imshow``) to build the
        colorbar from. If None, one is created from ``norm`` and ``cmap``.
    prefix_labels: bool
        if True, each tick label is rendered as ``'{key}, {value}'``;
        otherwise only the value is shown.

    Returns
    -------
    Colorbar
        the created colorbar, so callers can customise it further.
    """
    # NOTE: for a listed (discrete) colormap the ticks can only be set up to the
    # number of colors. Thus if the cmap_labels keys are not linear, sequential
    # values such as [0, 1, 2, 3] this will fail.
    # TODO: remap arbitrary keys onto 0..N-1 before setting the ticks.
    inset = ax.inset_axes(cbar_inset)
    fig = ax.get_figure()
    mappable = ScalarMappable(norm=norm, cmap=cmap) if im is None else im
    cb = fig.colorbar(mappable, cax=inset)
    cb.ax.set_title(title, loc='left', fontdict=cbar_titel_fontdict)
    if cmap_labels:
        if prefix_labels:
            labels = [f'{key}, {val}' for key, val in cmap_labels.items()]
        else:
            labels = list(cmap_labels.values())
        # Tick positions are the dict keys; labels follow the same order.
        cb.set_ticks(list(cmap_labels))
        cb.ax.set_yticklabels(labels, fontdict=cbar_labels_fontdict)
    else:
        cb.ax.tick_params(axis='y', labelsize=cbar_labels_fontdict['size'])
    return cb
def savefig(fig, save, **kwargs):
    """Save a figure and print the path it was written to.

    Parameters
    ----------
    fig: Figure
        figure to save
    save: str or os.PathLike
        target path. A bare filename (no directory component) is placed in
        ``plt.rcParams['savefig.directory']`` so plots can be saved less
        verbosely; a path with a directory component is used as given.
    **kwargs
        forwarded unchanged to ``fig.savefig`` (e.g. ``dpi``, ``bbox_inches``).
    """
    save = os.fspath(save)
    if save == os.path.basename(save):
        # rcParams stores the directory verbatim (matplotlib's default is '~'),
        # and fig.savefig does not expand it, so expand '~' here; otherwise the
        # file would be written into a literal './~' directory.
        directory = os.path.expanduser(plt.rcParams['savefig.directory'])
        save = os.path.join(directory, save)
    fig.savefig(save, **kwargs)
    print(f'Figure saved at: {save}')
def colormap(cmap):
    """Visualise a colormap as a single row of colour swatches.

    Parameters
    ----------
    cmap: Colormap
        colormap to display; each of its ``cmap.N`` entries is drawn as one
        pixel column.

    Returns
    -------
    fig, ax
        the created figure and axes.
    """
    # Integer inputs index the colormap's lookup table directly, giving an
    # (N, 4) RGBA array. ``cmap.N`` exists on every Colormap, whereas
    # ``.colors`` is only present on ListedColormap.
    rgba = cmap(np.arange(cmap.N))
    fig, ax = plt.subplots(1, 1)
    # imshow expects (rows, cols, channels): show the entries as one row.
    ax.imshow(rgba[np.newaxis, :, :])
    ax.tick_params(labelleft=False, left=False)
    fig.show()
    return fig, ax
def get_layout(nx, ny=None, max_col=5, max_row=None):
    """Return a ``(nrow, ncol)`` subplot grid large enough for all panels.

    The column count is ``ceil(sqrt(nx))``, capped at ``max_col`` and never
    below 1. The row count is then the smallest value with
    ``nrow * ncol >= nx + ny``.

    Parameters
    ----------
    nx: int
        number of panels that determine the column count
    ny: int, optional
        additional panels; they add rows but do not influence the column
        count. Defaults to 0.
    max_col: int
        upper bound on the number of columns
    max_row: optional
        accepted for backward compatibility; currently ignored.

    Returns
    -------
    tuple of int
        ``(nrow, ncol)``
    """
    if ny is None:
        ny = 0
    # Clamp to at least one column so nx == 0 does not divide by zero below.
    ncol = max(1, int(min(np.ceil(np.sqrt(nx)), max_col)))
    nrow = int(np.ceil((nx + ny) / ncol))
    return (nrow, ncol)