Calculate a minimum spanning tree with Python

This article shows how to calculate a minimum spanning tree with the Python package NetworkX.

1. Spanning tree

A spanning tree T of an undirected graph G is a subgraph that is a tree and includes all of the vertices of G. If a graph is not connected, it has spanning forests but no spanning tree. A minimum spanning tree (MST) is a spanning tree whose sum of edge weights is minimum. A maximum leaf spanning tree is a spanning tree that has the largest possible number of leaves among all spanning trees of G.
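The definition can be checked mechanically. The sketch below is only an illustration: the small graph and the helper name is_spanning_tree are made up for this example and are not part of NetworkX.

```python
import networkx as nx

# A small connected graph: a 4-cycle with one chord (illustrative data).
G = nx.Graph([(1, 2), (2, 3), (3, 4), (4, 1), (1, 3)])

# Candidate subgraph T: keeps three of the five edges.
T = nx.Graph([(1, 2), (2, 3), (3, 4)])


def is_spanning_tree(T, G):
    """Hypothetical helper: True if T is a tree on all vertices of G, using only edges of G."""
    same_vertices = set(T.nodes) == set(G.nodes)
    edges_from_G = all(G.has_edge(u, v) for u, v in T.edges)
    return same_vertices and edges_from_G and nx.is_tree(T)


print(is_spanning_tree(T, G))
```

A triangle on vertices 1, 2, 3 fails the same check twice over: it misses vertex 4 and it contains a cycle.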

1.1 Applications

1.2 An example

NetworkX provides interfaces to return a minimum spanning tree or its edges.

As an example, we calculate a minimum spanning tree of the Florentine families graph. (PS: the complete source code is hosted on my GitHub.)

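import pickle

import networkx as nx

# Load the graph from a pickle file.
# nx.read_gpickle was removed in NetworkX 3.0; on older versions,
# G = nx.read_gpickle(filename) does the same job.
filename = '../../florentine_families_graph.gpickle'
with open(filename, 'rb') as f:
    G = pickle.load(f)

# Calculate the minimum spanning tree
mst = nx.minimum_spanning_tree(G)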

As shown below, the edges in the minimum spanning tree are highlighted in red.

Fig. 1: The minimum spanning tree in the Florentine families graph
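For readers without the pickle file, here is a minimal sketch that uses NetworkX's built-in generator nx.florentine_families_graph() instead. It assumes the built-in graph matches the one stored in the file, which the article does not confirm. The graph carries no edge weights, so every edge counts as weight 1 and any spanning tree is a minimum one.

```python
import networkx as nx

# Built-in copy of the Florentine families graph (assumed equivalent to the pickle file).
G = nx.florentine_families_graph()

# Interface 1: get the tree itself as a new Graph.
mst = nx.minimum_spanning_tree(G)

# Interface 2: get only the edges, here as plain (u, v) pairs.
edge_list = list(nx.minimum_spanning_edges(G, data=False))

print(mst.number_of_nodes(), mst.number_of_edges())
print(edge_list[:3])
```

A spanning tree of a connected graph on n vertices always has n - 1 edges, which is a quick sanity check on either result.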

2. Kruskal’s algorithm

Kruskal’s algorithm is a minimum-spanning-tree algorithm. It is greedy: it builds a minimum spanning tree for a connected weighted graph by considering the edges in order of increasing cost and adding each edge that does not close a cycle. The correctness proof consists of two parts. First, it is proved that the algorithm produces a spanning tree. Second, it is proved that the constructed spanning tree is of minimal weight. The algorithm can be shown to run in \(O(E \log E)\) (or equivalently \(O(E \log V)\)) time, all with simple data structures [2].
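The equivalence of the two bounds follows from a short argument, sketched here under the assumption that the graph is simple (no self-loops or parallel edges) and connected, with \(V \ge 2\):

\[
V - 1 \le E \le \frac{V(V-1)}{2} < V^2
\quad\Longrightarrow\quad
\log (V - 1) \le \log E < 2 \log V .
\]

So \(\log E\) and \(\log V\) differ by at most a constant factor, and \(O(E \log E) = O(E \log V)\). The bound itself comes from sorting the edges, which costs \(O(E \log E)\); the union-find operations that follow are cheaper and do not change the total.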

2.1 The source code

nx.minimum_spanning_tree is based on nx.minimum_spanning_edges, which implements Kruskal’s algorithm.

import networkx as nx

def minimum_spanning_edges(G, weight='weight', data=True):
    from networkx.utils import UnionFind

    if G.is_directed():
        raise nx.NetworkXError(
            "Minimum spanning tree not defined for directed graphs.")

    subtrees = UnionFind()
    # Sort edges by weight in ascending order; a missing weight defaults to 1.
    edges = sorted(G.edges(data=True), key=lambda t: t[2].get(weight, 1))
    for u, v, d in edges:
        # u and v belong to different sets, so edge (u, v) cannot close a cycle.
        if subtrees[u] != subtrees[v]:
            if data:
                yield (u, v, d)
            else:
                yield (u, v)
            subtrees.union(u, v)
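To see the greedy choice at work, the following sketch runs the library's own nx.minimum_spanning_edges with algorithm="kruskal" on a tiny weighted graph. The graph and its weights are invented for illustration.

```python
import networkx as nx

# Illustrative weighted graph with distinct weights, so the MST is unique.
G = nx.Graph()
G.add_weighted_edges_from([
    ("a", "b", 1),
    ("b", "c", 2),
    ("a", "c", 3),  # would close the cycle a-b-c, so Kruskal skips it
    ("c", "d", 4),
])

mst_edges = list(nx.minimum_spanning_edges(G, algorithm="kruskal", data=True))
total = sum(d["weight"] for _, _, d in mst_edges)

for u, v, d in mst_edges:
    print(u, v, d["weight"])
print("total weight:", total)
```

The edge of weight 3 is rejected because a and c are already in the same set by the time it is examined, which is exactly the subtrees[u] != subtrees[v] test.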

The union-find data structure, networkx.utils.UnionFind, provides two operations:

  • X[item] returns a name for the set containing the given item.
  • X.union(*objects) merges the sets containing each item into a single larger set.
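The two operations can be tried on their own. This is a minimal sketch with made-up items "a", "b" and "c"; looking up an item that has not been seen yet places it in a set of its own.

```python
from networkx.utils import UnionFind

uf = UnionFind()

# Before any union, each item is the only member of its own set.
separate_before = uf["a"] != uf["b"]

# Merge the sets containing "a" and "b".
uf.union("a", "b")
same_after = uf["a"] == uf["b"]

# "c" was never merged, so it stays in a different set.
c_is_apart = uf["c"] != uf["a"]

print(separate_before, same_after, c_is_apart)
```

Kruskal's algorithm relies on exactly this comparison of set names to decide whether an edge joins two different subtrees.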

[1] Wikipedia: Spanning tree
[2] Wikipedia: Kruskal’s algorithm
[3] Wikipedia: Union–find data structure or Disjoint-set data structure

