Python scipy.cluster.hierarchy 模块,dendrogram() 实例源码

我们从Python开源项目中,提取了以下24个代码示例,用于说明如何使用scipy.cluster.hierarchy.dendrogram()

项目:lddmm-ot    作者:jeanfeydy    | 项目源码 | 文件源码
def set_figure_layout(self, width, height):
        """
        Apply the default figure-level settings to the dendrogram layout.

        :param width: Value stored under the layout's 'width' key.
        :param height: Value stored under the layout's 'height' key.
        :return: ``self.layout`` after the update and after both axes have
            been configured through ``set_axis_layout``.

        """
        figure_defaults = {
            'showlegend': False,
            'autosize': False,
            'hovermode': 'closest',
            'width': width,
            'height': height,
        }
        self.layout.update(figure_defaults)

        # The x axis is configured first, then the y axis.
        self.set_axis_layout(self.xaxis)
        self.set_axis_layout(self.yaxis)

        return self.layout
项目:ccCluster    作者:gsantoni    | 项目源码 | 文件源码
def avgTree(self):
        """
        Build the average-linkage tree from the pairwise table.

        Reads ``self.ccTable`` and ``self.Dimension``. For every non-empty
        row, items 0 and 1 are used as matrix indices and item 2 as the value
        stored symmetrically in a ``Dimension x Dimension`` matrix. The strict
        upper triangle of that matrix is flattened row by row into a
        condensed vector, on which average linkage is computed.

        Returns the linkage matrix, which is also stored in ``self.Tree``.
        """
        size = self.Dimension
        Matrix = np.zeros((size, size))

        for line in self.ccTable:
            # Skip missing or empty rows.  The length is compared by value:
            # an identity test against an int literal (``is not 0``) only
            # works through CPython's small-int cache and triggers a
            # SyntaxWarning on Python >= 3.8.
            if line is not None and len(line) != 0:
                Matrix[line[0], line[1]] = line[2]
                Matrix[line[1], line[0]] = line[2]

        # Strict upper triangle in row-major order, i.e. the condensed
        # layout that hierarchy.linkage accepts for a 1-D input.
        reducedArray = [Matrix[x, y]
                        for x in range(size)
                        for y in range(x + 1, size)]

        Distances = np.array(reducedArray, dtype=float)
        self.Tree = hierarchy.linkage(Distances, 'average')

        return self.Tree

# Function added to plot the dendrogram in shell mode only.
# Still not functioning.
# Uncomment when needed.
项目:ccCluster    作者:gsantoni    | 项目源码 | 文件源码
def plotTree(self, pos=None):
        """
        Draw the dendrogram of ``self.Tree`` with matplotlib and show it.

        :param pos: Optional index applied to the link coordinates and link
            colors to select which links are plotted by the loop below.
            A falsy value (the default ``None``) selects every link.
        """
        P = hierarchy.dendrogram(self.Tree, color_threshold=0.3)
        # NOTE(review): ``scipy.array`` is a NumPy alias that recent SciPy
        # releases deprecate - confirm the SciPy version this targets.
        icoord = scipy.array(P['icoord'])
        dcoord = scipy.array(P['dcoord'])
        color_list = scipy.array(P['color_list'])
        # Axis limits are taken from the full tree, before any selection.
        xmin, xmax = icoord.min(), icoord.max()
        ymin, ymax = dcoord.min(), dcoord.max()
        if pos:
            icoord = icoord[pos]
            # Fixed: the selection used to be bound to a misspelled name
            # (``ioord``), so dcoord was never filtered and went out of step
            # with icoord and color_list in the zip below.
            dcoord = dcoord[pos]
            color_list = color_list[pos]
        for xs, ys, color in zip(icoord, dcoord, color_list):
            plt.plot(xs, ys, color)
        plt.xlim(xmin - 10, xmax + 0.1 * abs(xmax))
        plt.ylim(ymin, ymax + 0.1 * abs(ymax))
        plt.show()
项目:aesop    作者:BioMoDeL    | 项目源码 | 文件源码
def plotESD(esd, filename=None, cmap='hot'):
    """Draw an electrostatic similarity heatmap.

    Parameters
    ----------
    esd : object
        Object exposing the matrix as ``esd.esd`` and the tick labels as
        ``esd.ids`` (presumably an ElecSimilarity instance that has already
        been run - confirm against the caller).
    filename : str, optional
        Path handed to ``fig.savefig``. When None, nothing is written.
    cmap : str, optional
        Colormap from matplotlib to use.

    Returns
    -------
    None
        Writes image to disk, if desired.
    """
    matrix = esd.esd
    n_first, n_second = matrix.shape[0], matrix.shape[1]

    fig, ax = plt.subplots(sharey=True)
    heatmap = ax.pcolor(matrix, cmap=cmap, vmin=0, vmax=2)
    ax.set_xlim(0, n_first)
    ax.set_ylim(0, n_second)
    # Ticks are shifted by half a unit relative to the integer grid.
    ax.set_xticks(np.arange(n_first) + 0.5, minor=False)
    ax.set_yticks(np.arange(n_second) + 0.5, minor=False)
    ax.set_xticklabels(esd.ids, rotation=90)
    ax.set_yticklabels(esd.ids)
    fig.colorbar(heatmap)
    plt.tight_layout()

    if filename is None:
        return
    fig.savefig(filename)


##########################################################################
# Function to plot ESD dendrogram
##########################################################################
项目:aesop    作者:BioMoDeL    | 项目源码 | 文件源码
def plotDend(esd, filename=None):
    """Draw an electrostatic similarity dendrogram.

    Parameters
    ----------
    esd : object
        Object exposing the matrix as ``esd.esd`` and the leaf labels as
        ``esd.ids`` (presumably an ElecSimilarity instance that has already
        been run - confirm against the caller).
    filename : str, optional
        Path handed to ``fig.savefig``. When None, nothing is written.

    Returns
    -------
    None
        Writes image to disk, if desired.
    """
    fig, ax = plt.subplots(sharey=True)

    linkage_matrix = cluster.linkage(esd.esd)
    dendro_options = {
        'labels': esd.ids,
        'leaf_rotation': 90.,   # rotates the x axis labels
        'leaf_font_size': 8.,   # font size for the x axis labels
        'ax': ax,
    }
    cluster.dendrogram(linkage_matrix, **dendro_options)

    plt.xlabel('Variants')
    plt.ylabel('ESD')
    plt.tight_layout()

    if filename is None:
        return
    fig.savefig(filename)
项目:word2vec_pipeline    作者:NIHOPA    | 项目源码 | 文件源码
def docv_centroid_order_idx(meta_clusters):
    """Return the dendrogram leaf order for the rows of ``meta_clusters``.

    Pairwise cosine distances between the rows are computed, an
    average-linkage tree is built from the resulting matrix, and the list of
    leaf indices reported by the (unplotted) dendrogram is returned.
    """
    # NOTE(review): the square distance matrix is handed to linkage as is,
    # so its rows are treated as observation vectors rather than as
    # precomputed distances - confirm this is intended.
    pairwise_cosine = cdist(meta_clusters, meta_clusters, metric='cosine')
    tree = hierarchy.linkage(pairwise_cosine, method='average')
    return hierarchy.dendrogram(tree, no_plot=True)["leaves"]
项目:nd_array    作者:KwatME    | 项目源码 | 文件源码
def cluster_2d_array_rows(array_2d,
                          linkage_method='average',
                          distance_function='euclidean'):
    """
    Order the rows of array_2d by hierarchical clustering.
    Arguments:
        array_2d (array): (n_rows, n_columns)
        linkage_method (str): passed to scipy.cluster.hierarchy.linkage as
            ``method``
        distance_function (str | callable): passed to
            scipy.cluster.hierarchy.linkage as ``metric``
    Returns:
        array: (n_rows); row indices in dendrogram leaf order
    """

    row_tree = linkage(
        array_2d, method=linkage_method, metric=distance_function)
    # no_plot=True: only the leaf ordering is needed, nothing is drawn.
    leaf_order = dendrogram(row_tree, no_plot=True)['leaves']

    return array(leaf_order)
项目:FreeDiscovery    作者:FreeDiscovery    | 项目源码 | 文件源码
def test_denrogram_children():
    """Check that _DendrogramChildren can be queried for every link."""
    # temporary solution for
    # https://stackoverflow.com/questions/40239956/node-indexing-in-hierarachical-clustering-dendrograms
    import numpy as np
    from scipy.cluster.hierarchy import dendrogram, linkage
    from freediscovery.cluster import _DendrogramChildren

    # two clusters sharing one covariance: a has 10 points, b has 5
    np.random.seed(1)
    covariance = [[3, 1], [1, 4]]
    a = np.random.multivariate_normal([10, 0], covariance, size=[10, ])
    b = np.random.multivariate_normal([0, 20], covariance, size=[5, ])
    X = np.concatenate((a, b),)
    Z = linkage(X, 'ward')
    # replace the merge distances by 1, 2, 3, ... so that they are uniform
    # (re-scales the distance axis when plotting)
    Z[:, 2] = np.arange(1, Z.shape[0] + 1)

    ddata = dendrogram(Z, no_plot=True)
    dc = _DendrogramChildren(ddata)
    # query the children of every link, in the order the dendrogram lists them
    links = zip(ddata['icoord'], ddata['dcoord'], ddata['color_list'])
    for idx, _ in enumerate(links):
        node_children = dc.query(idx)
    # the node queried last should encompass all samples
    assert len(node_children) == X.shape[0]
    assert_allclose(sorted(node_children), np.arange(X.shape[0]))
项目:lddmm-ot    作者:jeanfeydy    | 项目源码 | 文件源码
def get_color_dict(self, colorscale):
        """
        Returns colorscale used for dendrogram tree clusters.

        :param (list) colorscale: Colors to use for the plot in rgb format.
            When None, a built-in list of eight rgb strings is used.
        :rtype (dict): An OrderedDict keyed by the single-letter default
            color codes in alphabetical order. The i-th key maps to the
            i-th colorscale entry; keys beyond the end of the colorscale
            keep their default color name.

        """

        # These are the color codes returned for dendrograms
        # We're replacing them with nicer colors
        d = {'r': 'red',
             'g': 'green',
             'b': 'blue',
             'c': 'cyan',
             'm': 'magenta',
             'y': 'yellow',
             'k': 'black',
             'w': 'white'}
        # Keys are unique, so sorting the items orders them by key.
        default_colors = OrderedDict(sorted(d.items()))

        if colorscale is None:
            colorscale = [
                'rgb(0,116,217)',  # blue
                'rgb(35,205,205)',  # cyan
                'rgb(61,153,112)',  # green
                'rgb(40,35,35)',  # black
                'rgb(133,20,75)',  # magenta
                'rgb(255,65,54)',  # red
                'rgb(255,255,255)',  # white
                'rgb(255,220,0)']  # yellow

        # Pair the sorted keys with the colorscale entries. zip stops at the
        # shorter sequence, which replaces the index loop that rebuilt the
        # key list on every iteration.
        for k, color in zip(list(default_colors), colorscale):
            default_colors[k] = color

        return default_colors
项目:lddmm-ot    作者:jeanfeydy    | 项目源码 | 文件源码
def set_axis_layout(self, axis_key):
        """
        Sets and returns default axis object for dendrogram figure.

        :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', 'yaxis1', etc.
        :rtype (dict): The ``self.layout[axis_key]`` entry after the default
            axis settings have been merged into it.

        """
        if len(self.labels) > 0:
            # Labels go on self.yaxis for 'left'/'right' orientations and on
            # self.xaxis otherwise.
            if self.orientation in ['left', 'right']:
                label_axis = self.yaxis
            else:
                label_axis = self.xaxis
            if label_axis not in self.layout:
                self.layout[label_axis] = {}

            tick_settings = (
                ('tickvals',
                 [zv*self.sign[axis_key] for zv in self.zero_vals]),
                ('ticktext', self.labels),
                ('tickmode', 'array'),
            )
            for option, value in tick_settings:
                self.layout[label_axis][option] = value

        axis_defaults = dict(
            type='linear',
            ticks='outside',
            mirror='allticks',
            rangemode='tozero',
            showticklabels=True,
            zeroline=False,
            showgrid=False,
            showline=True,
        )
        self.layout[axis_key].update(axis_defaults)

        return self.layout[axis_key]
项目:ECoG-ClusterFlow    作者:sugeerth    | 项目源码 | 文件源码
def HierarchicalClustering(self,data):
        """
        Compute a linkage for ``data`` and print it with its dendrogram.

        ``data`` is converted with ``nx.to_numpy_matrix`` (presumably a
        networkx graph - confirm against the caller), the result is passed
        to ``linkage`` with default settings, and both the linkage matrix
        and the dict returned by ``dendrogram`` are printed.

        Returns the linkage matrix.
        """
        distances = nx.to_numpy_matrix(data)
        hierarchy = linkage(distances)
        # print() with one pre-formatted argument emits the same text on
        # Python 2 and Python 3; the former print statements were a
        # SyntaxError on Python 3.
        print("{} HIERRATCJY".format(hierarchy))
        Z = dendrogram(hierarchy)
        print(Z)
        return hierarchy
项目:lens    作者:ASIDataScience    | 项目源码 | 文件源码
def hierarchical_ordering_indices(columns, correlation_matrix):
    """Return the hierarchical cluster ordering of the columns.

    Parameters
    ----------
    columns: iterable of str
        Names of columns. Only its length is used.
    correlation_matrix: np.ndarray
        Matrix of correlation coefficients between columns. NaN entries
        are replaced by 0 before distances are computed.

    Returns
    -------
    indices: iterable of int
        Indices with order of columns. With two columns or fewer the
        natural order ``0 .. len(columns) - 1`` is returned unchanged.
    """
    n_columns = len(columns)
    if n_columns <= 2:
        return list(range(n_columns))

    cleaned = np.where(np.isnan(correlation_matrix), 0, correlation_matrix)
    # Euclidean distances between the rows of the cleaned matrix.
    condensed = distance.pdist(cleaned, metric='euclidean')
    tree = hierarchy.linkage(condensed, method='average')
    dendro = hierarchy.dendrogram(
        tree, no_plot=True, color_threshold=-np.inf)

    return dendro['leaves']
项目:sdaopt    作者:sgubianpm    | 项目源码 | 文件源码
def heat_map_reliability(data):
    """Plot a clustered heat map of success rates and save it as a PDF.

    A matrix with one row per entry of ``get_data_info(data, 'fnames')`` and
    one column per entry of ``get_data_info(data, 'methods')`` is filled
    with ``sum(data[k][f]['success']) * 100 / nb_runs``, where the column
    index follows the iteration order of ``data`` and the row index the
    iteration order of ``data[k]``. Rows are reordered by a Ward linkage
    whose dendrogram is drawn next to the matrix. The figure is written to
    'heatmap.pdf' in the current directory.

    Parameters
    ----------
    data : mapping
        Nested results as described above (presumably keyed by method name,
        then by function name - confirm against get_data_info).
    """
    nb_runs = get_data_info(data)
    nb_func = len(get_data_info(data, 'fnames'))
    methods = get_data_info(data, 'methods')
    mat = np.empty([nb_func, len(methods)])
    for j, k in enumerate(data):
        for i, f in enumerate(data[k]):
            success = np.sum(
                data[k][f]['success']) * 100 / nb_runs
            if np.isnan(success):
                success = 0
            mat[i, j] = success
    fig = plt.figure()
    # The dendrogram is drawn on the axes added last (the current axes).
    ax1 = fig.add_axes([0.7, 0.1, 0.18, 0.8])
    Y = fastcluster.linkage(mat, method='ward')
    Z1 = sch.dendrogram(Y, orientation='right')
    ax1.set_xticks([])
    ax1.set_yticks([])
    axmatrix = fig.add_axes([0.1, 0.1, 0.6, 0.8])
    axmatrix.set_title(
            'Success rate across test functions (reliability over 200 runs)')
    im = axmatrix.matshow(
            mat[Z1['leaves'], :], aspect='auto', origin='lower',
            cmap=plt.cm.RdYlGn,
            )
    # One tick per column, labelled with its method.  This replaces the
    # former trick of inserting a blank label into ``methods`` (which
    # mutated the list returned by get_data_info and relied on where the
    # automatic tick locator happened to place the first tick).
    axmatrix.set_xticks(np.arange(len(methods)))
    axmatrix.set_xticklabels(list(methods))
    # Reorder functions indexes:
    c = np.arange(0, nb_func)[Z1['leaves']]
    # Label every 10th row with the function index shown in that row; the
    # number of labels now matches the number of ticks.
    tick_rows = np.arange(0, nb_func, 10)
    axmatrix.set_yticks(tick_rows)
    axmatrix.set_yticklabels(c[tick_rows])
    axmatrix.set_ylabel('Test function number')
    axcolor = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    plt.colorbar(im, cax=axcolor)
    # Fixed: matplotlib figures expose ``savefig``; ``fig.save`` does not
    # exist and raised AttributeError.
    fig.savefig('heatmap.pdf', bbox_inches='tight', format='pdf')
项目:text-analytics-with-python    作者:dipanjanS    | 项目源码 | 文件源码
def plot_hierarchical_clusters(linkage_matrix, movie_data, figure_size=(8,12)):
    """Draw a left-oriented dendrogram labelled with movie titles.

    The labels are taken from ``movie_data['Title']`` and the figure is
    saved to 'ward_hierachical_clusters.png' at 200 dpi.

    :param linkage_matrix: linkage matrix passed to ``dendrogram``.
    :param movie_data: object whose ``['Title'].values.tolist()`` yields the
        leaf labels (presumably a pandas DataFrame - confirm).
    :param figure_size: ``figsize`` passed to ``plt.subplots``.
    """
    # set size
    fig, ax = plt.subplots(figsize=figure_size)
    movie_titles = movie_data['Title'].values.tolist()
    # plot dendrogram on the axes created above; its return value (a dict)
    # is not needed, so ``ax`` is no longer rebound to it
    dendrogram(linkage_matrix, orientation="left", labels=movie_titles,
               ax=ax)
    # Real booleans: the strings 'off' are not interpreted as False by
    # current matplotlib releases.
    ax.tick_params(axis='x',
                   which='both',
                   bottom=False,
                   top=False,
                   labelbottom=False)
    plt.tight_layout()
    plt.savefig('ward_hierachical_clusters.png', dpi=200)

# build ward's linkage matrix
项目:fri    作者:lpfann    | 项目源码 | 文件源码
def plot_dendrogram_and_intervals(intervals,linkage,threshold=0.55,ticklabels=None):
    """Plot a dendrogram above a bar chart of intervals in leaf order.

    :param intervals: array indexed by leaf order, then by ``[:, 0]`` (bar
        bottom) and ``[:, 1]`` (bar top).
    :param linkage: linkage matrix passed to ``dendrogram``.
    :param threshold: ``color_threshold`` of the dendrogram.
    :param ticklabels: optional array of labels, indexed by leaf order;
        when None the 1-based leaf indices are used.
    :return: the matplotlib figure holding both panels.
    """
    fig = plt.figure(figsize=(13, 6))

    # Upper panel: the dendrogram
    dendro_ax = fig.add_subplot(211)
    dendro = dendrogram(
        linkage,
        color_threshold=threshold,
        leaf_rotation=0.,  # rotates the x axis labels
        leaf_font_size=12.,  # font size for the x axis labels
        ax=dendro_ax,
    )
    # Leaf order determined by the linkage and the dendrogram
    leaf_order = dendro['leaves']
    ranges = intervals[leaf_order]

    # Lower panel: one bar per leaf, spanning its interval
    bar_ax = fig.add_subplot(212)
    if ticklabels is not None:
        ticks = list(ticklabels[leaf_order])
    else:
        ticks = np.array(leaf_order) + 1  # Index starting at 1

    positions = np.arange(len(ranges)) + 1
    lower_vals = ranges[:, 0]
    upper_vals = ranges[:, 1]
    bar_ax.bar(positions, upper_vals - lower_vals, 0.6,
               bottom=lower_vals, tick_label=ticks, align="center",
               linewidth=1.3)

    plt.ylabel('relevance', fontsize=19)
    plt.xlabel('feature', fontsize=19)
    plt.xticks(positions, ticks, rotation='vertical')
    bar_ax.margins(x=0)
    dendro_ax.set_xticks([])
    dendro_ax.margins(x=0)
    plt.tight_layout()

    return fig
项目:HiCembler    作者:lpryszcz    | 项目源码 | 文件源码
def getNewick(node, newick, parentdist, leaf_names):
    """Return the Newick string for the tree rooted at ``node``.

    ``newick`` is the text already built for everything that follows this
    node, ``parentdist`` the ``dist`` of the parent node, and ``leaf_names``
    maps leaf ids to names. Branch lengths are ``parentdist - node.dist``,
    written with two decimals. The right child is emitted before the left.
    """
    if node.is_leaf():
        return "%s:%.2f%s" % (
            leaf_names[node.id], parentdist - node.dist, newick)

    # Internal node: first write what closes this group ...
    if len(newick) == 0:
        closing = ");"
    else:
        closing = "):%.2f%s" % (parentdist - node.dist, newick)
    # ... then prepend the left subtree, then the right one.
    with_left = getNewick(node.get_left(), closing, node.dist, leaf_names)
    with_both = getNewick(
        node.get_right(), ",%s" % (with_left), node.dist, leaf_names)
    return "(%s" % (with_both)
项目:HiCembler    作者:lpryszcz    | 项目源码 | 文件源码
def plot_dendro(fn, iZ):
    """Plot the dendrogram of linkage matrix ``iZ`` and save it.

    ``fn`` is used unchanged as the output path when it is a string whose
    last dot-separated part has exactly three characters; otherwise '.png'
    is appended to it.
    """
    plt.figure(figsize=(25, 10))
    plt.title('Hierarchical Clustering Dendrogram')
    plt.xlabel('sample index')
    plt.ylabel('distance')
    sch.dendrogram(iZ, leaf_rotation=90., leaf_font_size=8.)
    # isinstance (rather than an exact type comparison) also accepts
    # subclasses of str.
    if isinstance(fn, str) and len(fn.split('.')[-1]) == 3:
        outfn = fn
    else:
        outfn = "%s.png" % fn
    plt.savefig(outfn)
项目:rnnlab    作者:phueb    | 项目源码 | 文件源码
def make_multi_hierarch_cluster_figs(model, field_input):
    """
    Returns a one-element list holding the multi-category clustering figure
    built from ``field_input`` (the categories to include).
    """
    def make_multi_cat_clust_fig(cats, freq_thr=500):  # TODO make into config
        """
        Returns fig showing hierarchical clustering of probes from multiple categories
        """
        start = time.time()
        sns.set_style('white')
        # make cat_acts_mat
        acts_mats = []
        cats_probe_list = []
        for cat in cats:
            # keep only probes whose summed term_doc_freq_dict entry exceeds freq_thr
            bool_index = [bool(sum(model.term_doc_freq_dict[probe]) > freq_thr)
                          for probe in model.probe_store.cat_probe_list_dict[cat]]

            cat_probe_acts_df = model.get_single_cat_acts_df(cat)
            filtered_df = cat_probe_acts_df[bool_index]
            acts_mats.append(filtered_df.values)
            cats_probe_list += [model.probe_store.probe_set[probe_id]
                                for probe_id in filtered_df.index.tolist()]
        # Fixed: np.vstack needs a sequence; handing it a generator is
        # deprecated and rejected by recent NumPy releases.
        cat_acts_mat = np.vstack(acts_mats)
        # fig
        rcParams['lines.linewidth'] = 2.0
        fig, ax = plt.subplots(figsize=(FigsConfigs.MAX_FIG_WIDTH, 5 * len(cats)), dpi=FigsConfigs.DPI)
        # dendrogram
        dist_matrix = pdist(cat_acts_mat, 'euclidean')
        linkages = linkage(dist_matrix, method='complete')
        dendrogram(linkages,
                   ax=ax,
                   labels=cats_probe_list,
                   orientation='right',
                   leaf_font_size=10)
        # Real booleans: the strings 'off' are not interpreted as False by
        # current matplotlib releases.
        ax.tick_params(axis='both', which='both', top=False, right=False, left=False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['top'].set_visible(False)
        print('{} completed in {:.1f} secs'.format(sys._getframe().f_code.co_name, time.time() - start))
        return fig

    figs = [make_multi_cat_clust_fig(field_input)]
    return figs
项目:ccCluster    作者:gsantoni    | 项目源码 | 文件源码
def createDendrogram(self):
        """
        Draw the dendrogram colored at ``self.threshold`` and redraw
        ``self.TreeCanvas``.

        NOTE(review): ``Tree`` is looked up as a bare name here, not as
        ``self.Tree`` - confirm it is defined at module level.
        """
        # The dict returned by dendrogram is not used.
        hierarchy.dendrogram(Tree, color_threshold=self.threshold)
        #self.textOutput.append('Plotted Dendrogram. Colored at a %s threshold for distance'%(threshold))
        self.TreeCanvas.draw()
项目:IgDiscover    作者:NBISweden    | 项目源码 | 文件源码
def main(args):
    """Plot a dendrogram of the sequences in ``args.fasta``.

    Reads ``args.fasta``, optionally tags every sequence that is not found
    in ``args.mark`` with ' (new)', clusters the sequences with
    ``args.method`` and saves the figure to ``args.plot``. With fewer than
    two sequences a 'no sequences' placeholder is drawn instead.
    """
    with FastaReader(args.fasta) as fr:
        sequences = list(fr)
    logger.info('Plotting dendrogram of %s sequences', len(sequences))

    if not args.mark:
        labels = [s.name for s in sequences]
    else:
        with FastaReader(args.mark) as fr:
            mark = PrefixComparer(record.sequence for record in fr)
        labels = []
        n_new = 0
        for record in sequences:
            is_new = record.sequence not in mark
            if is_new:
                n_new += 1
            labels.append(record.name + (' (new)' if is_new else ''))
        logger.info('%s sequence(s) marked as "new"', n_new)

    sns.set_style("white")
    # Font size is derived from the label count and clamped to [6, 16];
    # the figure height then follows from the clamped size.
    padded_count = len(labels) + 5
    font_size = 297 / 25.4 * 72 / padded_count
    font_size = min(16, max(6, font_size))
    height = font_size * padded_count / 72
    fig = plt.figure(figsize=(210 / 25.4, height))
    matplotlib.rcParams.update({'font.size': 4})
    ax = fig.gca()
    sns.despine(ax=ax, top=True, right=True, left=True, bottom=True)
    sns.set_style('whitegrid')

    if len(sequences) < 2:
        ax.text(0.5, 0.5, 'no sequences', fontsize='xx-large')
    else:
        m = distances([s.sequence for s in sequences])
        y = distance.squareform(m)
        smallest = y.min()
        logger.info('Smallest distance is %s. Found between:', int(smallest))
        # report each closest pair once (i < j)
        for i, j in np.argwhere(m == smallest):
            if i < j:
                logger.info('%s and %s', labels[i], labels[j])
        linkage_matrix = hierarchy.linkage(y, method=args.method)
        hierarchy.dendrogram(
            linkage_matrix,
            labels=labels,
            leaf_font_size=font_size,
            orientation='right',
            color_threshold=0.95*max(linkage_matrix[:,2]))

    ax.grid(False)
    fig.set_tight_layout(True)
    fig.savefig(args.plot)
项目:icing    作者:slipguru    | 项目源码 | 文件源码
def analyse(sm, labels, root='', plotting_context=None, file_format='pdf',
            force_silhouette=False, threshold=None):
    """Perform analysis.

    Parameters
    ----------
    sm : matrix, shape = [n_samples, n_samples]
        Precomputed similarity matrix. It must provide ``toarray()``;
        distances are taken as ``1 - sm.toarray()``.
    labels : array, shape = [n_samples]
        Association of each sample to a cluster.
    root : string
        The root path for the output creation.
    plotting_context : dict, None, or one of {paper, notebook, talk, poster}
        See seaborn.set_context().
    file_format : ('pdf', 'png')
        Choose the extension for output images. The histogram file always
        gets a '.pdf' suffix appended to the dendrogram file name.
    force_silhouette : bool
        Run the silhouette analysis even when the matrix has 8000 rows or
        more.
    threshold : float or None
        Height of the dashed red line drawn across the dendrogram. When
        None, no line is drawn.
    """
    sns.set_context(plotting_context)

    if force_silhouette or sm.shape[0] < 8000:
        silhouette.plot_clusters_silhouette(1. - sm.toarray(), labels,
                                            max(labels), root=root,
                                            file_format=file_format)
    else:
        # logging.warn is a deprecated alias of logging.warning.
        logging.warning(
            "Silhouette analysis is not performed due to the "
            "matrix dimensions. With a matrix %ix%i, you would need to "
            "allocate %.2fMB in memory. If you know what you are doing, "
            "specify 'force_silhouette = True' in the config file in %s, "
            "then re-execute the analysis.\n", sm.shape[0], sm.shape[0],
            sm.shape[0]**2 * 8 / (2.**20), root)

    # Generate dendrogram
    import scipy.spatial.distance as ssd
    # Dense distances, computed once and shared by the dendrogram and the
    # histogram below.
    dist_matrix = 1. - sm.toarray()
    Z = hierarchy.linkage(ssd.squareform(dist_matrix), method='complete',
                          metric='euclidean')

    plt.close()
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(20, 15)
    hierarchy.dendrogram(Z, ax=ax)
    # threshold defaults to None; only draw the cut line when one is given.
    if threshold is not None:
        ax.axhline(threshold, color="red", linestyle="--")
    plt.show()
    filename = os.path.join(root, 'dendrogram_{}.{}'
                                  .format(extra.get_time(), file_format))
    fig.savefig(filename)
    logging.info('Figured saved %s', filename)

    plt.close()
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(20, 15)
    # ``normed=False`` was the default behaviour and the keyword no longer
    # exists in current matplotlib, so it is simply omitted.
    plt.hist(dist_matrix, bins=50)
    plt.ylim([0, 10])
    fig.savefig(filename + "_histogram_distances.pdf")
项目:TTClust    作者:tubiana    | 项目源码 | 文件源码
def init_log(args, mdtrajectory):
    """
    DESCRIPTION
    Initialise the logfile with some general information.
    ----
    Args:
        args (dict): dictionary of all arguments (argparse); the keys
            "top", "traj", "select_traj", "select_alignement",
            "select_rmsd", "logfile", "method", "ngroup" and "cutoff"
            are read.
        mdtrajectory: trajectory object whose ``n_frames`` and ``n_atoms``
            attributes are written to the log.
    """
    topo = args["top"]
    traj = args["traj"]
    selection_string = args["select_traj"]
    select_align = args["select_alignement"]
    select_rmsd = args["select_rmsd"]
    # log file name without its extension
    logname = os.path.splitext(args["logfile"])[0]


    LOGFILE.write("========================================================\n")
    LOGFILE.write("====================  TTCLUST {}  ===================\n"
                  .format(__version__))
    LOGFILE.write("========================================================\n")
    LOGFILE.write("\n")


    LOGFILE.write("************ General information ************\n")
    LOGFILE.write("software version   : {}\n".format(__version__))
    LOGFILE.write("Created on         : {}\n".format(datetime.datetime.now()))
    write_command_line()
    LOGFILE.write("DESTINATION FOLDER : {}\n".format(os.getcwd()+"/"+logname))
    LOGFILE.write("ARGUMENTS : \n")
    LOGFILE.write("  Selection string :\n")
    LOGFILE.write("      Atoms selected in trajectory = {} \n".format(
                                                        selection_string))
    LOGFILE.write("      Atoms selected for alignement = {} \n".format(
                                                        select_align))
    LOGFILE.write("      Atoms selected for RMSD = {} \n".format(select_rmsd))
    LOGFILE.write("  trajectory file  : {} \n".format(traj))
    LOGFILE.write("   Number of frames  : {} \n".format(mdtrajectory.n_frames))
    LOGFILE.write("   Number of atoms  : {} \n".format(mdtrajectory.n_atoms))
    LOGFILE.write("  topology file    : {} \n".format(topo))
    LOGFILE.write("  method used of clusterring : {}".format(args["method"]))
    LOGFILE.write("\n\n")
    if args["ngroup"]:
        LOGFILE.write("  Number of cluster asked: {}\n".format(args["ngroup"]))
    if args["cutoff"]:
        # Fixed: the literal string "cutoff" used to be written instead of
        # the cutoff value itself.
        LOGFILE.write("  cutoff for dendrogram clustering: {}\n".format(
            args["cutoff"]))
项目:TTClust    作者:tubiana    | 项目源码 | 文件源码
def plot_dendro(linkage, logname, cutoff, color_list,clusters_list):
    """
    DESCRIPTION
    This function will create the dendrogram graph with the corresponding
    cluster color and save it as "<logname>/<logname>-den.png".
    Args:
        linkage (numpy array) : linkage matrix
        logname (str) : name used for both the output folder and the
            file prefix
        cutoff (float) : cutoff used for clustering
        color_list (list) : color of each cluster, converted to HEX here
        clusters_list (list) : list of all clusters; each one provides
            ``frames`` and a 1-based ``id`` indexing ``color_list``
    Returns:
        None
    """
    # NOTE(review): the classic style is only applied when the matplotlib
    # major version is exactly "2" - confirm whether later versions should
    # be covered as well.
    if mpl.__version__[0] == "2":
        STYLE = "classic"
        if STYLE in plt.style.available:
            plt.style.use(STYLE)
    plt.figure()
    #Convert RGB color to HEX color
    color_hex = [mpl.colors.rgb2hex(x) for x in color_list]
    sch.set_link_color_palette(color_hex)
    # Map every frame to the HEX color of its cluster
    color_member = {}
    for cl in clusters_list:
        for frm in cl.frames:
            color_member[frm] = mpl.colors.rgb2hex(color_list[cl.id-1])

    #Attribute the correct color for each branch.
    #adapted from Ulrich Stern code in StackOverflow http://stackoverflow.com/a/38208611
    # Ids up to len(linkage) are leaves (frames); larger ids are merged
    # nodes already present in link_cols.  A link keeps a cluster color
    # only when both children share it, otherwise it is grey.
    link_cols = {}
    for i, i12 in enumerate(linkage[:,:2].astype(int)):
        c1, c2 = (link_cols[x] if x > len(linkage) else color_member[x] for x in i12)
        link_cols[i+1+len(linkage)] = c1 if c1 == c2 else "#808080"

    #Dendrogram creation
    sch.dendrogram(linkage, color_threshold=float(cutoff), above_threshold_color="#808080", link_color_func=lambda x: link_cols[x])

    #Graph parameters
    plt.title("Clustering Dendrogram")
    # Fixed: plt.gca() returns the axes the dendrogram was drawn on, whereas
    # plt.axes() may add a fresh, empty axes on recent matplotlib releases.
    ax = plt.gca()
    ax.set_xticklabels([])
    plt.axhline(y=float(cutoff), color = "grey") # cutoff value horizontal line
    ax.set_ylabel("Distance (AU)")
    ax.set_xlabel("Frames")

    plt.savefig("{0}/{0}-den.png".format(logname), format="png", dpi=DPI, transparent=True)
    plt.close()
项目:rnnlab    作者:phueb    | 项目源码 | 文件源码
def make_hierarch_cluster_figs(model):
    """
    Returns a list with one clustering figure per category in
    ``model.probe_store.cat_set``.
    """
    def make_cat_cluster_fig(cat, bottom_off=False, num_probes_limit=20):
        """
        Returns fig showing hierarchical clustering of probes in single category
        """
        start = time.time()
        sns.set_style('white')
        # load data
        cat_prototypes_df = model.get_single_cat_acts_df(cat)
        probes_in_cat = [model.probe_store.probe_set[probe_id] for probe_id in cat_prototypes_df.index.tolist()]
        num_probes_in_cat = len(probes_in_cat)
        # randomly subsample (without replacement) when over the limit
        if num_probes_limit and num_probes_in_cat > num_probes_limit:
            ids = np.random.choice(range(num_probes_in_cat), num_probes_limit, replace=False)
            cat_prototypes_df = cat_prototypes_df.iloc[ids]
            # loop variable renamed so it no longer shadows the builtin id()
            probes_in_cat = [probes_in_cat[probe_idx] for probe_idx in ids]
        # fig
        rcParams['lines.linewidth'] = 2.0
        fig, ax = plt.subplots(figsize=(FigsConfigs.MAX_FIG_WIDTH, 4), dpi=FigsConfigs.DPI)
        # dendrogram
        dist_matrix = pdist(cat_prototypes_df.values, 'euclidean')
        linkages = linkage(dist_matrix, method='complete')
        dendrogram(linkages,
                   ax=ax,
                   leaf_label_func=lambda x: probes_in_cat[x],
                   orientation='right',
                   leaf_font_size=8)
        ax.set_title(cat)
        ax.set_xlim([0, FigsConfigs.CAT_CLUSTER_XLIM])
        # Real booleans: the strings 'off' are not interpreted as False by
        # current matplotlib releases.
        ax.tick_params(axis='both', which='both', top=False, right=False, left=False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['top'].set_visible(False)
        if bottom_off:
            ax.xaxis.set_ticklabels([])  # hides ticklabels
            ax.tick_params(axis='both', which='both', bottom=False)
            ax.spines['bottom'].set_visible(False)
        fig.tight_layout()
        print('{} completed in {:.1f} secs'.format(sys._getframe().f_code.co_name, time.time() - start))
        return fig

    figs = [make_cat_cluster_fig(cat) for cat in model.probe_store.cat_set]
    return figs