Python seaborn 模块，set_context() 实例源码

我们从 Python 开源项目中提取了以下 30 个代码示例，用于说明如何使用 seaborn.set_context()。

项目:word2vec_pipeline    作者:NIHOPA    | 项目源码 | 文件源码
def plot_heatmap():
    """Draw the clustered dispersion heatmap and verify its dendrogram order.

    Loads the dispersion data and renders it with ``sns.clustermap``, using
    the stored linkage for both rows and columns. It then checks that the
    column order seaborn produced equals the saved ``dendrogram_order``.
    """
    dispersion = load_dispersion_data()
    tree = dispersion["linkage"]

    sns.set_context("notebook", font_scale=1.25)
    grid = sns.clustermap(
        data=dispersion["dispersion"],
        row_linkage=tree,
        col_linkage=tree,
        vmin=0.50,
        vmax=1.00,
        cmap=cmap_clustermap,
        figsize=(12, 10),
    )

    # Sanity check: the plotted dendrogram must reproduce the saved order.
    plotted_order = grid.data2d.columns
    assert (plotted_order == dispersion["dendrogram_order"]).all()
项目:activity-browser    作者:LCA-ActivityBrowser    | 项目源码 | 文件源码
def __init__(self, parent, mlca, width=6, height=6, dpi=100):
    """Build a heatmap of ``mlca.results`` on a new matplotlib ``Figure``.

    Rows are labelled from ``mlca.func_units`` (the first key of each unit,
    formatted with ``format_activity_label``). Columns are labelled from
    ``mlca.methods``, with each method's parts joined by newlines.
    """
    fig = Figure(figsize=(width, height), dpi=dpi, tight_layout=True)
    ax = fig.add_subplot(111)

    super(LCAResultsPlot, self).__init__(fig)
    self.setParent(parent)

    row_labels = [
        format_activity_label(next(iter(unit.keys())))
        for unit in mlca.func_units
    ]
    # Palette taken from the seaborn colour-palette tutorial:
    # https://stanford.edu/~mwaskom/software/seaborn/tutorial/color_palettes.html
    palette = sns.cubehelix_palette(8, start=.5, rot=-.75, as_cmap=True)
    column_labels = ["\n".join(method) for method in mlca.methods]

    heatmap_ax = sns.heatmap(
        mlca.results,
        annot=True,
        linewidths=.05,
        cmap=palette,
        xticklabels=column_labels,
        yticklabels=row_labels,
        ax=ax,
        square=False,
    )
    heatmap_ax.tick_params(labelsize=8)

    self.setMinimumSize(self.size())
项目:double-dqn    作者:musyoku    | 项目源码 | 文件源码
def plot_evaluation_episode_reward():
    """Plot evaluation scores per episode and save two PNGs to ``args.plot_dir``.

    Each row of ``csv_evaluation`` supplies the episode (``row[0]``), the
    average score (``row[1]``) and the median score (``row[2]``). Every
    series is prefixed with a 0 entry. Writes
    ``evaluation_episode_average_reward.png`` and
    ``evaluation_episode_median_reward.png``.
    """
    pylab.clf()
    sns.set_context("poster")
    pylab.plot(0, 0)
    episodes = [0]
    average_scores = [0]
    median_scores = [0]
    # Iterate the rows directly: the old ``xrange`` index loop only ran on
    # Python 2.
    for row in csv_evaluation:
        episodes.append(row[0])
        average_scores.append(row[1])
        median_scores.append(row[2])
    pylab.plot(episodes, average_scores, sns.xkcd_rgb["windows blue"], lw=2)
    pylab.xlabel("episodes")
    pylab.ylabel("average score")
    pylab.savefig("%s/evaluation_episode_average_reward.png" % args.plot_dir)

    pylab.clf()
    pylab.plot(0, 0)
    pylab.plot(episodes, median_scores, sns.xkcd_rgb["windows blue"], lw=2)
    pylab.xlabel("episodes")
    pylab.ylabel("median score")
    pylab.savefig("%s/evaluation_episode_median_reward.png" % args.plot_dir)
项目:DAB_analyzer    作者:meklon    | 项目源码 | 文件源码
def plot_group(data_frame, path_output):
    """Save a per-group box plot to ``<path_output>/summary_statistics.png``.

    Parameters
    ----------
    data_frame : table accepted by ``seaborn.boxplot(data=...)``
        Must contain the columns ``"Group"`` and ``"DAB+ area, %"``.
    path_output : str
        Directory in which the image is written.
    """
    # optional import
    import seaborn as sns
    path_output_image = os.path.join(path_output, "summary_statistics.png")

    sns.set_style("whitegrid")
    sns.set_context("talk")
    fig = plt.figure(num=None, figsize=(15, 7), dpi=120)
    plt.ylim(0, 100)
    plt.title('Box plot')
    sns.boxplot(x="Group", y="DAB+ area, %", data=data_frame)

    plt.tight_layout()
    plt.savefig(path_output_image, dpi=300)
    # Release the figure; otherwise every call leaves another one open.
    plt.close(fig)
项目:syracuse_public    作者:dssg    | 项目源码 | 文件源码
def plot_predict_proba(y_pred_probs, clf, pdf=None):
    """Plots the predict proba distribution.

    Parameters
    ----------
    y_pred_probs : array-like
        Predicted probabilities passed to ``sns.distplot``.
    clf : str
        Name shown in the title (concatenated with ``' proba'``).
    pdf : object with ``savefig(figure)``, optional
        Presumably a matplotlib ``PdfPages``. When given, the figure is
        appended to it and closed. Otherwise the figure is shown.
    """
    # Style and context must be set before the axes exist; setting them
    # afterwards (as before) left this plot unstyled.
    sns.set_style("white")
    sns.set_context("poster",
                    font_scale=2.25,
                    rc={"lines.linewidth": 1.25, "lines.markersize": 8})
    fig, ax = plt.subplots(1, figsize=(18, 8))
    sns.distplot(y_pred_probs, ax=ax)
    ax.set_xlabel('predict_proba')
    ax.set_ylabel('frequency')
    ax.set_title(clf + ' proba')
    if pdf:
        pdf.savefig(fig)
        plt.close(fig)
    else:
        plt.show()
项目:astetik    作者:mikkokotila    | 项目源码 | 文件源码
def swarm(data, x, y, xscale='linear', yscale='linear'):
    """Draw a swarm plot of column ``y`` against column ``x`` of ``data``.

    Applies seaborn's white/muted defaults and sets the requested axis
    scales. The lower y-limit is set to the minimum of ``data[y]``, and
    ``sns.despine()`` is applied.
    """
    # Global seaborn defaults for a clean look.
    sns.set(style="white", palette="muted")
    sns.set_context("notebook", font_scale=1, rc={"lines.linewidth": 0.2})

    ax = sns.swarmplot(x=x, y=y, data=data, palette='RdYlGn')
    plt.tick_params(axis='both', which='major', pad=10)

    ax.set(xscale=xscale)
    ax.set(yscale=yscale)

    # Double .min() presumably lets ``y`` also select several columns.
    lower = data[y].min().min()
    plt.ylim(lower,)

    sns.despine()
项目:astetik    作者:mikkokotila    | 项目源码 | 文件源码
def histogram(data, variables):
    """Draw one histogram per entry of ``variables``, side by side.

    Parameters
    ----------
    data : table indexable by column name (e.g. a pandas DataFrame)
        Source data; each name in ``variables`` must be a column.
    variables : sequence of column names
        One subplot is drawn per name. A single name now works too.
    """
    sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 0})
    sns.set_style('white')

    # squeeze=False keeps ``axes`` a 2-D array even for one variable. Without
    # it plt.subplots returns a bare Axes, and indexing it raised TypeError.
    fig, axes = plt.subplots(1, len(variables), figsize=(19, 5), squeeze=False)

    for ax, name in zip(axes[0], variables):
        ax.hist(data[name], lw=0, color="indianred", bins=8)
        ax.tick_params(axis='both', which='major', pad=15)
        ax.set_xlabel(name)
        ax.set_yticklabels([])

    sns.despine(left=True)
项目:astetik    作者:mikkokotila    | 项目源码 | 文件源码
def correlation(data, title=''):
    """Plot the lower triangle of the Spearman correlation matrix of ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Columns to correlate.
    title : str, optional
        Axes title; nothing is set when empty (previously always ignored).

    Note: mutates global matplotlib ``rcParams`` (figure size, font, dpi).
    """
    corr = data.corr(method='spearman')
    # Hide the upper triangle, including the diagonal.
    mask = np.zeros_like(corr)
    mask[np.triu_indices_from(mask)] = True

    sns.set(style="white")
    sns.set_context("notebook", font_scale=2, rc={"lines.linewidth": 0.3})

    rcParams['figure.figsize'] = 25, 12
    rcParams['font.family'] = 'Verdana'
    rcParams['figure.dpi'] = 300

    # xticklabels=True gives one tick per column so the relabel below aligns.
    g = sns.heatmap(corr, mask=mask, linewidths=1, cmap="RdYlGn",
                    annot=False, xticklabels=True)
    # Label with the columns actually in ``corr``. ``data`` may hold columns
    # that .corr() dropped, which would shift every label.
    g.set_xticklabels(corr.columns, rotation=25, ha="right")
    plt.tick_params(axis='both', which='major', pad=15)
    if title:
        g.set_title(title)
项目:dueling-network    作者:musyoku    | 项目源码 | 文件源码
def plot_evaluation_episode_reward():
    """Plot evaluation scores per episode and save two PNGs to ``args.plot_dir``.

    Each row of ``csv_evaluation`` supplies the episode (``row[0]``), the
    average score (``row[1]``) and the median score (``row[2]``). Every
    series is prefixed with a 0 entry. Writes
    ``evaluation_episode_average_reward.png`` and
    ``evaluation_episode_median_reward.png``.
    """
    pylab.clf()
    sns.set_context("poster")
    pylab.plot(0, 0)
    episodes = [0]
    average_scores = [0]
    median_scores = [0]
    # Iterate the rows directly: the old ``xrange`` index loop only ran on
    # Python 2.
    for row in csv_evaluation:
        episodes.append(row[0])
        average_scores.append(row[1])
        median_scores.append(row[2])
    pylab.plot(episodes, average_scores, sns.xkcd_rgb["windows blue"], lw=2)
    pylab.xlabel("episodes")
    pylab.ylabel("average score")
    pylab.savefig("%s/evaluation_episode_average_reward.png" % args.plot_dir)

    pylab.clf()
    pylab.plot(0, 0)
    pylab.plot(episodes, median_scores, sns.xkcd_rgb["windows blue"], lw=2)
    pylab.xlabel("episodes")
    pylab.ylabel("median score")
    pylab.savefig("%s/evaluation_episode_median_reward.png" % args.plot_dir)
项目:Iris-Classification-with-Heroku    作者:gaborvecsei    | 项目源码 | 文件源码
def plotPrediction(pred):
    """
    Plot the prediction as a bar chart, then encode it to base64.

    :param pred: prediction accuracies, one per label in
        ``['setosa', 'versicolor', 'virginica']`` order
    :return: base64 encoded JPEG image
    """
    # Local imports: BytesIO works on Python 2 and 3 (cStringIO is Py2-only),
    # and ``sns.plt`` no longer exists in seaborn >= 0.8.
    import io
    import matplotlib.pyplot as plt

    labels = ['setosa', 'versicolor', 'virginica']
    # ``figure.figsize`` is not a seaborn *context* key, so passing it to
    # sns.set_context() was silently ignored. Size the figure directly.
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        with sns.color_palette("RdBu_r", 3):
            sns.barplot(x=labels, y=pred, ax=ax)
        ax.set(ylim=(0, 1))

        # Encode the rendered figure instead of displaying it.
        image_buffer = io.BytesIO()
        fig.savefig(image_buffer, format='jpg')
    finally:
        # Always release the figure so repeated calls do not accumulate them.
        plt.close(fig)
    image_buffer.seek(0)
    return base64.b64encode(image_buffer.read())
项目:pymoku    作者:liquidinstruments    | 项目源码 | 文件源码
def phase1_plot_setup():
    """Create the 1x2 'Phase 1 - Rise Times' figure and maximise its window.

    Returns
    -------
    tuple
        ``(figure, left_axes, right_axes)``.
    """
    # Set the context first so the poster font sizes apply to the new axes.
    sns.set_context('poster')

    # Set up a 1x2 plot
    f, (ax1, ax2) = plt.subplots(1, 2)
    f.suptitle('Phase 1 - Rise Times', fontsize=18, fontweight='bold')

    # Maximise the plotting window. Matplotlib may report backend names in
    # different case (e.g. 'WXAgg', 'tkagg'), so compare case-insensitively.
    # The old 'wxAgg' literal could never match.
    plot_backend = matplotlib.get_backend().lower()
    mng = plt.get_current_fig_manager()
    if plot_backend == 'tkagg':
        mng.resize(*mng.window.maxsize())
    elif plot_backend == 'wxagg':
        mng.frame.Maximize(True)
    elif plot_backend in ('qt4agg', 'qt5agg', 'qtagg'):
        mng.window.showMaximized()

    return f, ax1, ax2
项目:pymoku    作者:liquidinstruments    | 项目源码 | 文件源码
def phase2_plot_setup():
    """Create the 1x1 'Phase 2 - Line Width' figure and maximise its window.

    Returns
    -------
    tuple
        ``(figure, axes)``.
    """
    # Set the context first so the poster font sizes apply to the new axes.
    sns.set_context('poster')

    # Set up a 1x1 plot
    f, ax1 = plt.subplots(1, 1)
    f.suptitle('Phase 2 - Line Width', fontsize=18, fontweight='bold')

    # Maximise the plotting window. Matplotlib may report backend names in
    # different case (e.g. 'WXAgg', 'tkagg'), so compare case-insensitively.
    # The old 'wxAgg' literal could never match.
    plot_backend = matplotlib.get_backend().lower()
    mng = plt.get_current_fig_manager()
    if plot_backend == 'tkagg':
        mng.resize(*mng.window.maxsize())
    elif plot_backend == 'wxagg':
        mng.frame.Maximize(True)
    elif plot_backend in ('qt4agg', 'qt5agg', 'qtagg'):
        mng.window.showMaximized()

    return f, ax1
项目:fitbit-analyzer    作者:5agado    | 项目源码 | 文件源码
def plotSleepValueHeatmap(intradayStats, sleepValue=1):
    """Draw a one-row heatmap of the ``sleepValue`` row of ``intradayStats``.

    Parameters
    ----------
    intradayStats : pandas.DataFrame
        Row-indexed by sleep value. Columns are presumably time slots;
        only every 60th column name is shown as a tick label.
    sleepValue : int
        Row label to plot; also the key into ``sleepStats.SLEEP_VALUES``
        used for the y-axis label.
    """
    sns.set_context("poster")
    sns.set_style("darkgrid")

    # Keep every ``stepSize``-th column name as a label; blank the rest.
    stepSize = 60
    columns = list(intradayStats.columns.values)
    xticks = [''] * len(columns)
    xticks[::stepSize] = columns[::stepSize]

    plt.figure(figsize=(16, 4.2))
    # Series.reshape was removed from pandas; reshape the underlying array.
    row = intradayStats.loc[sleepValue].values.reshape(1, -1)
    # Give heatmap the full label list so every column gets its own tick.
    # Relabelling seaborn's thinned 'auto' ticks afterwards misaligned them.
    g = sns.heatmap(row, xticklabels=xticks)
    g.set_xticklabels(xticks, rotation=45)
    g.set_yticklabels([])
    g.set_ylabel(sleepStats.SLEEP_VALUES[sleepValue])
    plt.tight_layout()
    # ``sns.plt`` was removed in seaborn 0.8.
    plt.show()
项目:saapy    作者:ashapochka    | 项目源码 | 文件源码
def style_matplotlib_for_notebook(self):
    """Apply seaborn's "notebook" context with overrides from the config.

    The override mapping is read from ``self.cfg['matplotlib']['notebook']``.
    """
    notebook_overrides = self.cfg['matplotlib']['notebook']
    sns.set_context("notebook", rc=notebook_overrides)
项目:sci-pype    作者:jay-johnson    | 项目源码 | 文件源码
def sb_initialize_fonts(self, font_size=12, title_size=12, label_size=10):
    """Switch seaborn to the "paper" context with custom font sizes.

    :param font_size: value for ``font.size``
    :param title_size: value for ``axes.titlesize``
    :param label_size: value for ``axes.labelsize``
    """
    import seaborn as sns
    font_overrides = {
        "font.size": font_size,
        "axes.titlesize": title_size,
        "axes.labelsize": label_size,
    }
    sns.set_context("paper", rc=font_overrides)
    # end of sb_initialize_fonts
项目:double-dqn    作者:musyoku    | 项目源码 | 文件源码
def plot_episode_reward():
    """Plot the score of every episode and save ``episode_reward.png``.

    Each row of ``csv_episode`` supplies the episode (``row[0]``) and the
    score (``row[1]``). Both series start with a 0 entry. The image is
    written to ``args.plot_dir``.
    """
    pylab.clf()
    sns.set_context("poster")
    pylab.plot(0, 0)
    episodes = [0]
    scores = [0]
    # Iterate the rows directly: the old ``xrange`` index loop only ran on
    # Python 2.
    for row in csv_episode:
        episodes.append(row[0])
        scores.append(row[1])
    pylab.plot(episodes, scores, sns.xkcd_rgb["windows blue"], lw=2)
    pylab.xlabel("episodes")
    pylab.ylabel("score")
    pylab.savefig("%s/episode_reward.png" % args.plot_dir)
项目:double-dqn    作者:musyoku    | 项目源码 | 文件源码
def plot_training_episode_highscore():
    """Plot the training high score per episode and save it as a PNG.

    Each row of ``csv_training_highscore`` supplies the episode (``row[0]``)
    and the high score (``row[1]``). Both series start with a 0 entry. The
    figure is written to ``args.plot_dir/training_episode_highscore.png``.
    """
    pylab.clf()
    sns.set_context("poster")
    pylab.plot(0, 0)
    episodes = [0]
    highscore = [0]
    # Iterate the rows directly: the old ``xrange`` index loop only ran on
    # Python 2.
    for row in csv_training_highscore:
        episodes.append(row[0])
        highscore.append(row[1])
    pylab.plot(episodes, highscore, sns.xkcd_rgb["windows blue"], lw=2)
    pylab.xlabel("episodes")
    pylab.ylabel("highscore")
    pylab.savefig("%s/training_episode_highscore.png" % args.plot_dir)
项目:datanode    作者:jay-johnson    | 项目源码 | 文件源码
def sb_initialize_fonts(self, font_size=12, title_size=12, label_size=10):
    """Switch seaborn to the "paper" context with custom font sizes.

    :param font_size: value for ``font.size``
    :param title_size: value for ``axes.titlesize``
    :param label_size: value for ``axes.labelsize``
    """
    import seaborn as sns
    font_overrides = {
        "font.size": font_size,
        "axes.titlesize": title_size,
        "axes.labelsize": label_size,
    }
    sns.set_context("paper", rc=font_overrides)
    # end of sb_initialize_fonts
项目:icing    作者:slipguru    | 项目源码 | 文件源码
def plot_learning_function(xdata, ydata, yerr, order, aplot, poly):
    """Plot measured similarities with error bars and the fitted curve.

    Parameters
    ----------
    xdata, ydata, yerr : array-like
        Measured points and their error bars.
    order : int (presumably)
        Polynomial order; only used in the legend label.
    aplot : str
        Output path passed to ``plt.savefig``.
    poly : callable
        Fitted function, evaluated on 1000 points spanning ``xdata``.
    """
    with sns.axes_style('whitegrid'):
        sns.set_context('paper')
        # Dense evaluation grid shaped as a column vector.
        grid = np.linspace(np.min(xdata), np.max(xdata), 1000)[:, None]

        plt.figure()
        plt.errorbar(xdata, ydata, yerr,
                     label='Nearest similarity', marker='s')
        curve_label = 'Learning function (poly of order {})'.format(order)
        plt.plot(grid, poly(grid), '-', label=curve_label)

        plt.xlabel(r'Mutation level')
        plt.ylabel(r'Average similarity (not normalised)')
        plt.legend(loc='lower left')
        plt.savefig(aplot, transparent=True, bbox_inches='tight')
        plt.close()
项目:astetik    作者:mikkokotila    | 项目源码 | 文件源码
def kde(x, y, title='', color='YlGnBu', xscale='linear', yscale='linear'):
    """Draw a shaded bivariate KDE of ``x`` and ``y``.

    Parameters
    ----------
    x, y : array-like
        Samples for the two axes.
    title : str
        Axes title.
    color : str
        Colormap name, passed as ``cmap``.
    xscale, yscale : str
        Axis scales, e.g. ``'linear'`` or ``'log'``.
    """
    sns.set_style('white')
    sns.set_context('notebook', font_scale=1, rc={"lines.linewidth": 0.5})
    # ``set_title`` is not a kdeplot option; it was forwarded to the contour
    # call as an unknown keyword, so it is dropped here.
    g = sns.kdeplot(x, y, shade=True, cut=2, cmap=color, shade_lowest=False,
                    legend=True)
    plt.tick_params(axis='both', which='major', pad=10)
    # ``sns.plt`` was removed in seaborn 0.8; title the axes directly.
    g.set_title(title)

    g.set(xscale=xscale)
    g.set(yscale=yscale)

    sns.despine()
项目:astetik    作者:mikkokotila    | 项目源码 | 文件源码
def bubble(data, x, y, hue, bsize, palette='Reds', xscale='linear', yscale='linear', title='', suptitle=0):
    """Draw a bubble chart of column ``y`` against ``x``, coloured by ``hue``.

    Marker areas are proportional to ``data[bsize]``, scaled so the largest
    value maps to an area of 1500.

    Parameters
    ----------
    data : pandas.DataFrame
    x, y, hue, bsize : str
        Column names.
    palette : str
        Colour palette for the hue levels.
    xscale, yscale : str
        Axis scales.
    title : str
        Axes title.
    suptitle : str or 0
        Text after ``"size = "`` in the figure suptitle; 0 means use ``bsize``.
    """
    if suptitle == 0:
        suptitle = bsize

    sns.set(style="whitegrid")
    sns.set_context("notebook", font_scale=3, rc={"lines.linewidth": 0.3})

    sns.set_color_codes("bright")

    sizes = data[bsize] * (1500 / float(data[bsize].max()))

    def _scatter(xs, ys, **kws):
        # PairGrid calls this once per hue level with a subset of rows, so
        # pick the matching marker sizes by index. Passing the full-length
        # ``sizes`` did not line up with the subset. seaborn hands over
        # Series (same type as ``sizes``); otherwise fall back to all sizes.
        subset = sizes.loc[xs.index] if isinstance(xs, type(sizes)) else sizes
        plt.scatter(xs, ys, s=subset, **kws)

    g = sns.PairGrid(data, hue=hue, palette=palette, y_vars=y, x_vars=x, size=12, aspect=3)

    g.map(_scatter)

    g.set(xscale=xscale)
    g.set(yscale=yscale)

    g.add_legend(title=hue, bbox_to_anchor=(0.9, 0.6))

    plt.title(title, fontsize=48, y=1.12, color="gray")

    plt.suptitle("size = " + suptitle, verticalalignment='top', fontsize=35, y=1.01, x=0.48, color="gray")
    plt.xlabel(x, fontsize=38, labelpad=30, color="gray")
    plt.ylabel(y, fontsize=38, labelpad=30, color="gray")

    plt.tick_params(axis='both', which='major', pad=25)

    plt.axhline(linewidth=2.5, color="black")
    plt.axvline(linewidth=2.5, color="black")
    plt.ylim(0,)
    plt.xlim(0,)
项目:dueling-network    作者:musyoku    | 项目源码 | 文件源码
def plot_episode_reward():
    """Plot the score of every episode and save ``episode_reward.png``.

    Each row of ``csv_episode`` supplies the episode (``row[0]``) and the
    score (``row[1]``). Both series start with a 0 entry. The image is
    written to ``args.plot_dir``.
    """
    pylab.clf()
    sns.set_context("poster")
    pylab.plot(0, 0)
    episodes = [0]
    scores = [0]
    # Iterate the rows directly: the old ``xrange`` index loop only ran on
    # Python 2.
    for row in csv_episode:
        episodes.append(row[0])
        scores.append(row[1])
    pylab.plot(episodes, scores, sns.xkcd_rgb["windows blue"], lw=2)
    pylab.xlabel("episodes")
    pylab.ylabel("score")
    pylab.savefig("%s/episode_reward.png" % args.plot_dir)
项目:dueling-network    作者:musyoku    | 项目源码 | 文件源码
def plot_training_episode_highscore():
    """Plot the training high score per episode and save it as a PNG.

    Each row of ``csv_training_highscore`` supplies the episode (``row[0]``)
    and the high score (``row[1]``). Both series start with a 0 entry. The
    figure is written to ``args.plot_dir/training_episode_highscore.png``.
    """
    pylab.clf()
    sns.set_context("poster")
    pylab.plot(0, 0)
    episodes = [0]
    highscore = [0]
    # Iterate the rows directly: the old ``xrange`` index loop only ran on
    # Python 2.
    for row in csv_training_highscore:
        episodes.append(row[0])
        highscore.append(row[1])
    pylab.plot(episodes, highscore, sns.xkcd_rgb["windows blue"], lw=2)
    pylab.xlabel("episodes")
    pylab.ylabel("highscore")
    pylab.savefig("%s/training_episode_highscore.png" % args.plot_dir)
项目:datawatch    作者:WideOpen    | 项目源码 | 文件源码
def update_graph(dff):
    """Plot ``overdue`` against ``date`` and save it to ``docs/graph.png``.

    Parameters
    ----------
    dff : pandas.DataFrame
        Must have ``date`` and ``overdue`` columns; every 10th row is plotted.
    """
    # "ticks" sets every style key that "white" does, so the former
    # set_style("white") call immediately before it had no effect.
    sns.set_style("ticks")

    sns.set_context("talk")
    # ``.ix`` was removed in pandas 1.0; ``.iloc[::10]`` selects the same rows.
    ax = dff.iloc[::10].plot("date", "overdue", figsize=(7, 4), lw=3)
    onemonth = datetime.timedelta(days=30)
    # Leave a month of empty space after the last date.
    plt.xlim(dff.date.min(), dff.date.max() + onemonth)
    plt.ylabel("Overdue dataset")
    plt.xlabel("Date")
    plt.savefig("docs/graph.png")
    # Release the figure so repeated updates do not accumulate open figures.
    plt.close(ax.figure)
项目:pygcam    作者:JGCRI    | 项目源码 | 文件源码
def setupPlot(context="talk", style="white", font_scale=1.0):
    """Configure seaborn's plotting context, then its axes style."""
    sns.set_context(context=context, font_scale=font_scale)
    sns.set_style(style=style)
项目:jira-metrics-extract    作者:rnwolf    | 项目源码 | 文件源码
def set_context(context="talk"):
    """Set seaborn's plotting context ("talk" unless told otherwise)."""
    sns.set_context(context=context)
项目:icing    作者:slipguru    | 项目源码 | 文件源码
def analyse(sm, labels, root='', plotting_context=None, file_format='pdf',
            force_silhouette=False, threshold=None):
    """Perform analysis.

    Produces a silhouette plot (for matrices with fewer than 8000 rows, or
    when forced). It then plots a complete-linkage dendrogram built from
    ``1 - sm`` and a histogram of those distances.

    Parameters
    ----------
    sm : matrix with ``.toarray()``, shape = [n_samples, n_samples]
        Precomputed similarity matrix.
    labels : array, shape = [n_samples]
        Association of each sample to a cluster.
    root : string
        The root path for the output creation.
    plotting_context : dict, None, or one of {paper, notebook, talk, poster}
        See seaborn.set_context().
    file_format : ('pdf', 'png')
        Choose the extension for output images.
    force_silhouette : bool
        Run the silhouette analysis even for large matrices.
    threshold : float or None
        If given, drawn as a dashed red horizontal line on the dendrogram.
    """
    sns.set_context(plotting_context)

    if force_silhouette or sm.shape[0] < 8000:
        silhouette.plot_clusters_silhouette(1. - sm.toarray(), labels,
                                            max(labels), root=root,
                                            file_format=file_format)
    else:
        # logging.warn is a deprecated alias of logging.warning.
        logging.warning(
            "Silhouette analysis is not performed due to the "
            "matrix dimensions. With a matrix %ix%i, you would need to "
            "allocate %.2fMB in memory. If you know what you are doing, "
            "specify 'force_silhouette = True' in the config file in %s, "
            "then re-execute the analysis.\n", sm.shape[0], sm.shape[0],
            sm.shape[0]**2 * 8 / (2.**20), root)

    # Dense distance matrix, built once and reused by both plots below.
    dist = 1. - sm.toarray()

    # Generate dendrogram
    import scipy.spatial.distance as ssd
    Z = hierarchy.linkage(ssd.squareform(dist), method='complete',
                          metric='euclidean')

    plt.close()
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(20, 15)
    hierarchy.dendrogram(Z, ax=ax)
    # axhline(None) raises, and None is the default, so only draw a real cut.
    if threshold is not None:
        ax.axhline(threshold, color="red", linestyle="--")
    filename = os.path.join(root, 'dendrogram_{}.{}'
                                  .format(extra.get_time(), file_format))
    # Save before showing: an interactive show() can close the figure.
    fig.savefig(filename)
    logging.info('Figure saved %s', filename)
    plt.show()

    plt.close()
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(20, 15)
    # ``normed`` was removed from Matplotlib's hist(); omitting it gives the
    # same raw counts that normed=False requested.
    plt.hist(dist, bins=50)
    plt.ylim([0, 10])
    # NOTE(review): this name embeds the dendrogram's extension and is always
    # PDF regardless of ``file_format``; kept as-is for compatibility.
    fig.savefig(filename + "_histogram_distances.pdf")
    plt.close(fig)
项目:astetik    作者:mikkokotila    | 项目源码 | 文件源码
def bubble(data,x,y,hue,bsize,palette='Reds',xscale='linear',yscale='linear',title='',suptitle=0,xlim_start=0,ylim_start=0,xlim_end=0,ylim_end=0):
    """Draw a bubble chart of column ``y`` against ``x``, coloured by ``hue``.

    x > should be int or float (column name of numeric data)
    y > should be int or float (column name of numeric data)
    hue > should be boolean or category (column name)
    bsize > column name; only used as the default ``suptitle`` text

    Markers are drawn at a fixed size (``s=5000``).

    ``suptitle`` == 0 means "use ``bsize``". A limit argument equal to 0
    means "derive it from the data": the starts become the column minimum.
    The ends become the maximum plus 10% of the range for y, or 20% for x.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt

    if suptitle == 0:
        suptitle = bsize

    # 10% of each variable's range, used as padding on the upper limits.
    y_modifier = (data[y].max() - data[y].min()) * 0.1
    x_modifier = (data[x].max() - data[x].min()) * 0.1

    if ylim_start == 0:
        ylim_start = data[y].min()

    if xlim_start == 0:
        xlim_start = data[x].min()

    if ylim_end == 0:
        ylim_end = data[y].max() + y_modifier

    if xlim_end == 0:
        xlim_end = data[x].max() + (x_modifier * 2)

    sns.set(style="whitegrid")
    sns.set_context("notebook", font_scale=3, rc={"lines.linewidth": 0.3})

    sns.set_color_codes("bright")

    g = sns.PairGrid(data, hue=hue, palette=palette, y_vars=y, x_vars=x, size=12, aspect=3)
    g.map(plt.scatter, s=5000)
    g.set(xscale=xscale)
    g.set(yscale=yscale)
    g.add_legend(title=hue, bbox_to_anchor=(0.9, 0.6))

    plt.title(title, fontsize=48, y=1.12, color="gray")
    plt.suptitle("size = " + suptitle, verticalalignment='top', fontsize=35, y=1.01, x=0.48, color="gray")
    plt.xlabel(x, fontsize=38, labelpad=30, color="gray")
    plt.ylabel(y, fontsize=38, labelpad=30, color="gray")
    plt.tick_params(axis='both', which='major', pad=25)
    plt.axhline(linewidth=2.5, color="black")
    plt.axvline(linewidth=2.5, color="black")
    plt.ylim(ylim_start, ylim_end)
    plt.xlim(xlim_start, xlim_end)
项目:astetik    作者:mikkokotila    | 项目源码 | 文件源码
def bars(data, color='black', title=''):
    """Bar chart of up to 20 value counts of ``data``, skipping the top one.

    Parameters
    ----------
    data : pandas.Series
        Raw values, counted with ``value_counts()``.
    color : str
        Bar colour; bars with a negative value are drawn red instead.
    title : str
        Chart title.
    """
    data = pd.DataFrame(data.value_counts())
    data = data.reset_index()
    data.columns = ['keyword', 'value']
    # Blank out the first (most frequent) keyword; dropna() then removes it.
    data['keyword'] = data['keyword'][1:]
    data = data.dropna()
    data = data.reset_index(drop=True)
    data = data.sort_values('value', ascending=False)

    sns.set_context("notebook", font_scale=1.2, rc={"lines.linewidth": 0})
    # Set the style before the axes exist; afterwards it was not applied.
    sns.set_style('white')

    top = data.head(20)
    x = top['keyword'].astype(str)
    y = top['value'].astype(int)

    f, ax = plt.subplots(figsize=(16, 3))

    ## change color of the bar based on value
    colors = [color if _y >= 0 else 'red' for _y in y]

    # Keyword x/y: newer seaborn rejects positional data vectors.
    sns.barplot(x=x, y=y, palette=colors, ax=ax)

    plt.title(title, fontsize=18, y=1.12, color="gray")

    ax.set_ylabel('')
    ax.set_xlabel('')

    for n, _y in enumerate(y):
        # Text passed positionally: the ``s=`` keyword was removed from
        # Axes.annotate in Matplotlib 3.5.
        ax.annotate(
            '{:.1f}'.format(abs(_y)),
            xy=(n, _y),
            ha='center', va='center',
            xytext=(0, -10),
            size=12,
            textcoords='offset points',
            color="white",
            weight="bold"
        )
    ax.set_yticklabels([])
    ax.set_xticklabels(top['keyword'], rotation=25, ha="right")
    ax.tick_params(axis='both', which='major', pad=15)
    sns.despine(left=True)
项目:pygcam    作者:JGCRI    | 项目源码 | 文件源码
def plotForcingSubplots(tsdata, filename=None, ci=95, show_figure=False, save_fig_kwargs=None):
    """Plot one time-series panel per experiment, side by side, sharing y.

    Parameters
    ----------
    tsdata : pandas.DataFrame
        Has ``expName`` and ``runId`` columns. All other columns are melted
        into a ``year``/``value`` long format for ``tsm.tsplot``.
    filename : str, optional
        If given, the figure is saved there.
    ci : int
        Confidence interval passed to ``tsm.tsplot``.
    show_figure : bool
        Call ``plt.show()`` after plotting.
    save_fig_kwargs : dict, optional
        Extra keyword arguments for ``fig.savefig`` (ignored unless a dict).

    Returns
    -------
    matplotlib.figure.Figure
    """
    sns.set_context('paper')
    expList = tsdata['expName'].unique()

    nrows = 1
    ncols = len(expList)
    width = 2 * ncols
    height = 2
    # squeeze=False keeps ``axes`` indexable even for a single experiment.
    # Otherwise plt.subplots returns a bare Axes and zip() below fails.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharey=True,
                             figsize=(width, height), squeeze=False)
    axes = axes[0]

    def dataForExp(expName):
        # Boolean mask instead of a string-built query(), which breaks on
        # experiment names containing quotes.
        df = tsdata[tsdata['expName'] == expName].copy()
        df.drop(['expName'], axis=1, inplace=True)
        df = pd.melt(df, id_vars=['runId'], var_name='year')
        return df

    for i, (ax, expName) in enumerate(zip(axes, expList)):
        df = dataForExp(expName)

        pos = expName.find('-')
        title = expName[:pos] if pos >= 0 else expName
        ax.set_title(title.capitalize())

        tsm.tsplot(df, time='year', unit='runId', value='value', ci=ci, ax=ax)

        # Only the left-most panel carries the y-axis label.
        ax.set_ylabel('W m$^{-2}$' if i == 0 else '')
        ax.set_xlabel('')  # no need to say "year"
        ax.axhline(0, color='navy', linewidth=0.5, linestyle='-')
        plt.setp(ax.get_xticklabels(), rotation=270)

    plt.tight_layout()

    # Save the file
    if filename:
        if isinstance(save_fig_kwargs, dict):
            fig.savefig(filename, **save_fig_kwargs)
        else:
            fig.savefig(filename)

    # Display the figure
    if show_figure:
        plt.show()

    return fig