2026-06-23 · 9 min read

Why Tree All-Reduce Is 2·log₂(N), Not 2(N-1)

The ring's 2(N-1) grows linearly with core count: 510 steps at 256 cores. Tree all-reduce computes the identical sum but folds the cores up a binary tree and unfolds the result back down, costing 2·log₂(N). This post traces the same four-core matmul to show why the tree wins on latency for small tensors and loses on bandwidth for large ones.

All-Reduce · Collective Communication · Distributed Training · GPU Infrastructure · Tree All-Reduce

The previous post traced ring all-reduce and showed its cost is 2(N-1) steps: N-1 for reduce-scatter and N-1 for all-gather. The ring is bandwidth optimal, but that 2(N-1) is a problem, because the step count grows linearly with the number of cores. At 256 cores the ring takes 510 sequential steps before a single token comes out. For small tensors, where each step is mostly fixed latency rather than data movement, that linear growth is the whole cost.

Tree all-reduce attacks exactly that weakness. It computes the identical sum and lands the identical result on every core, but instead of passing data around a ring it folds the cores up a binary tree and unfolds the result back down. The cost is 2 log2(N) steps. For 256 cores that is 16 steps instead of 510. This post derives that 2 log2(N) the same way the ring post derived 2(N-1): start from a real matmul, compute the partial sums by hand, then trace every number up the tree and back down until the formula is something you can point at.

Same four cores, same distinct values, no zeros, because repeated numbers are impossible to follow when the point is watching them move. I reuse the exact matmul from the ring post so you can lay the two traces side by side.

Where the partial sums come from

One linear layer, Y = X @ W, one token. X is 1 x 4 and W is 4 x 4:

X = [ 3  2  1  3 ]        W = [ 5  4  3  2 ]   row 0
                              [ 1  4  3  2 ]   row 1
                              [ 1  3  4  2 ]   row 2
                              [ 4  2  1  5 ]   row 3

Computed normally, the answer is one 1 x 4 vector:

Y = X @ W = [ 30  29  22  27 ]

That is the target every core must end up holding.

Split across four cores along the contraction dimension. Row-parallel: W cut along its rows, X cut along its matching columns. Core i gets column i of X, a scalar, and row i of W, a length-4 vector:

Core 0:  X_0 = 3    W_0 = [ 5  4  3  2 ]
Core 1:  X_1 = 2    W_1 = [ 1  4  3  2 ]
Core 2:  X_2 = 1    W_2 = [ 1  3  4  2 ]
Core 3:  X_3 = 3    W_3 = [ 4  2  1  5 ]

Each core computes X_i @ W_i, the scalar times the row, and gets a full 1 x 4 partial sum:

Core 0:  3 * [ 5  4  3  2 ]  =  [ 15  12   9   6 ]
Core 1:  2 * [ 1  4  3  2 ]  =  [  2   8   6   4 ]
Core 2:  1 * [ 1  3  4  2 ]  =  [  1   3   4   2 ]
Core 3:  3 * [ 4  2  1  5 ]  =  [ 12   6   3  15 ]

Sum down the columns to confirm:

position 0:  15 +  2 +  1 + 12 = 30
position 1:  12 +  8 +  3 +  6 = 29
position 2:   9 +  6 +  4 +  3 = 22
position 3:   6 +  4 +  2 + 15 = 27

Target: [30, 29, 22, 27]
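The hand arithmetic above can be checked with a few lines of plain Python. This is only a sketch of the row-parallel split using the post's numbers; the variable names (`partials`, `target`, `reference`) are illustrative and not taken from any library.

```python
# Row-parallel split of Y = X @ W across four cores, using the post's values.
X = [3, 2, 1, 3]
W = [
    [5, 4, 3, 2],  # row 0
    [1, 4, 3, 2],  # row 1
    [1, 3, 4, 2],  # row 2
    [4, 2, 1, 5],  # row 3
]

# Core i holds the scalar X[i] and the row W[i]; its partial is the scaled row.
partials = [[X[i] * w for w in W[i]] for i in range(4)]

# Summing the four partials column by column is the job of the all-reduce.
target = [sum(column) for column in zip(*partials)]

# Reference: the ordinary matmul with no split at all.
reference = [sum(X[k] * W[k][j] for k in range(4)) for j in range(4)]

for core, partial in enumerate(partials):
    print(f"Core {core}: {partial}")
print("Target:", target)
```

If `target` and `reference` agree, the split along the contraction dimension has not changed the result, only where the additions happen.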

These four partial vectors are the input to the tree. The job is the same as any all-reduce: add the four vectors, leave every core holding [30, 29, 22, 27]. What changes is the routing.

The structural idea: fold, then unfold

The ring threads data through every core in sequence, which is why it pays N-1 per phase. The tree refuses to go through cores one at a time. It pairs them up and collapses pairs in parallel.

Think about adding four numbers. You can do it sequentially, a+b, then +c, then +d, which is three dependent steps. Or you can do (a+b) and (c+d) at the same time, then add the two results, which is two steps because the first two additions happen simultaneously. The tree is the second strategy applied to whole vectors sitting on different cores. Halve the number of live cores at every level, and you reach a single core holding the full sum in log2(N) levels instead of N-1.
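The difference between the two strategies is the length of the dependency chain, which can be counted directly. The sketch below uses scalars only and hypothetical helper names; it assumes every addition within one pairwise level can run at the same time.

```python
def sequential_depth(values):
    """Add left to right; each addition waits on the previous one."""
    total, depth = values[0], 0
    for v in values[1:]:
        total += v
        depth += 1
    return total, depth


def pairwise_depth(values):
    """Add neighbours in pairs; all pairs in one level count as one step."""
    level, depth = list(values), 0
    while len(level) > 1:
        nxt = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # an unpaired leftover is carried up unchanged
            nxt.append(level[-1])
        level = nxt
        depth += 1
    return level[0], depth


four = [15, 2, 1, 12]  # position 0 of the four partial sums
print(sequential_depth(four))  # same sum, 3 dependent steps
print(pairwise_depth(four))    # same sum, 2 levels
```

Both functions return the same sum; only the depth differs, and the gap widens as the input grows (255 versus 8 for 256 values).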

That handles the reduction, which is data flowing inward to one root core. But the postcondition still demands every core hold the result, so after the root has the sum you reverse the tree and broadcast it back outward, doubling the live cores at each level, another log2(N) levels. Fold up in log2(N), unfold down in log2(N). The total is 2 log2(N), and the 2 is the same structural 2 as the ring: one factor for the reduction inward, one for the broadcast outward. The difference is only how many steps each direction costs.

Trace both directions.

Phase one: reduce-up, where log2(N) comes from

Arrange the four cores as a binary tree. Pair (0,1) and pair (2,3) at the bottom, with core 0 and core 2 as the two sub-roots, and core 0 as the overall root.

Level 1, distance 1. Within each pair, the odd core sends its whole vector to the even core, which adds it in. Both pairs do this at the same time, in one step.

Core 1 sends [ 2  8  6  4 ] to Core 0
Core 3 sends [ 12  6  3 15 ] to Core 2

The receivers add:

Core 0:  [15 12  9  6] + [ 2  8  6  4]  =  [17 20 15 10]
Core 2:  [ 1  3  4  2] + [12  6  3 15]  =  [13  9  7 17]

State after level 1. Cores 1 and 3 have done their job and drop out. Cores 0 and 2 now each hold a pair-sum:

Core 0:  [17 20 15 10]   = P0 + P1
Core 1:  (done, idle)
Core 2:  [13  9  7 17]   = P2 + P3
Core 3:  (done, idle)

Level 2, distance 2. The two surviving cores combine. Core 2 sends its whole pair-sum to core 0, which adds it in.

Core 2 sends [13 9 7 17] to Core 0
Core 0:  [17 20 15 10] + [13 9 7 17]  =  [30 29 22 27]

Core 0 now holds the complete sum [30, 29, 22, 27]. That took 2 steps, which is log2(4). Each step halved the number of active cores: 4 cores, then 2, then 1. Notice no core ever handled a chunk. Every message in the tree is the whole vector. That is the price the tree pays, and we return to it.
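The fold traced above can be simulated in a few lines. This is a minimal sketch, assuming a power-of-two core count and the pairing used in this post (at each level the core at `dst + distance` sends to `dst`); `reduce_up` is a hypothetical name, not a library call.

```python
def reduce_up(vectors):
    """Fold per-core vectors onto core 0; returns (root_vector, steps).

    Assumes len(vectors) is a power of two.
    """
    state = [list(v) for v in vectors]
    n = len(state)
    steps = 0
    distance = 1
    while distance < n:
        # Receivers are the cores whose index is a multiple of 2 * distance.
        # Every pair at this level exchanges in the same step.
        for dst in range(0, n, 2 * distance):
            src = dst + distance
            # The sender ships its whole vector; the receiver adds elementwise.
            state[dst] = [a + b for a, b in zip(state[dst], state[src])]
        steps += 1
        distance *= 2
    return state[0], steps


partials = [
    [15, 12, 9, 6],
    [2, 8, 6, 4],
    [1, 3, 4, 2],
    [12, 6, 3, 15],
]
root, steps = reduce_up(partials)
print(root, steps)
```

The `distance` variable doubles each level (1, then 2), matching the "distance 1" and "distance 2" labels in the trace, and the loop runs log2(N) times.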

Why log2(N) and not N-1? Because the tree collapses pairs in parallel. The ring forces a partial to walk past every core one at a time, N-1 hops. The tree halves the live set each level, so it reaches the root in the number of times you can halve N, which is log2(N). For N=4 the gap is small, 2 versus 3. For N=256 it is 8 versus 255, and that is the entire reason the tree exists.

Phase two: broadcast-down, the second log2(N)

The root holds the answer. The postcondition needs it on all four cores. Reverse the tree. The root unfolds the result outward, doubling the live set each level, the mirror image of the fold.

Level 1, distance 2. Core 0 sends the whole result to core 2, the other sub-root.

Core 0 sends [30 29 22 27] to Core 2
Core 0:  [30 29 22 27]
Core 2:  [30 29 22 27]

Level 2, distance 1. Each sub-root sends to its child. Both happen in the same step.

Core 0 sends [30 29 22 27] to Core 1
Core 2 sends [30 29 22 27] to Core 3

Final state:

Core 0:  [30 29 22 27]
Core 1:  [30 29 22 27]
Core 2:  [30 29 22 27]
Core 3:  [30 29 22 27]

Every core holds the target. That took 2 steps, again log2(4). The broadcast doubled the informed set each level: 1 core, then 2, then 4. Same shape as the fold, run backward.
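The unfold is the same loop run with the distance shrinking instead of growing. Again a sketch under the same assumptions (power-of-two core count, core 0 as root); `broadcast_down` is an illustrative name.

```python
def broadcast_down(result, n):
    """Unfold the root's vector to all n cores; returns (state, steps).

    Assumes n is a power of two and core 0 starts with the result.
    """
    state = [None] * n
    state[0] = list(result)
    steps = 0
    distance = n // 2
    while distance >= 1:
        # Every core that already holds the result sends one full copy
        # to the core `distance` away; all sends in a level share one step.
        for src in range(0, n, 2 * distance):
            state[src + distance] = list(state[src])
        steps += 1
        distance //= 2
    return state, steps


state, steps = broadcast_down([30, 29, 22, 27], 4)
for core, vec in enumerate(state):
    print(f"Core {core}: {vec}")
print("steps:", steps)
```

For four cores the distances are 2 and then 1, the reverse of the fold, and the informed set grows 1, 2, 4.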

Adding the two halves

reduce-up:        log2(N) steps   (fold inward, halve live cores each level)
broadcast-down:   log2(N) steps   (unfold outward, double informed cores each level)
tree all-reduce:  2*log2(N) steps

For N=4 that is 2 + 2 = 4 steps, exactly what we traced. The 2 is the same 2 as the ring: one log2(N) to reduce inward, one log2(N) to broadcast outward. Both algorithms are reduce-then-broadcast. They differ only in how each direction is routed, which sets whether each direction costs N-1 or log2(N).

The ring grows linearly. The tree grows logarithmically. If steps were the only thing that mattered, the tree would always win and the ring would not exist.
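To see the two growth rates side by side, the step formulas can be tabulated for a few core counts. This sketch assumes power-of-two N; the function names are illustrative.

```python
def ring_steps(n):
    """Reduce-scatter plus all-gather: N-1 steps each."""
    return 2 * (n - 1)


def tree_steps(n):
    """Reduce-up plus broadcast-down: log2(N) levels each.

    (n - 1).bit_length() equals log2(n) when n is a power of two.
    """
    return 2 * (n - 1).bit_length()


for n in (4, 8, 64, 256):
    print(f"N={n:>3}  ring={ring_steps(n):>3}  tree={tree_steps(n):>2}")
```

At N=4 the counts are 6 and 4; at N=256 they are 510 and 16, the figures quoted at the top of the post.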

The catch: tree is not bandwidth optimal

Steps are not the only thing that matters. Look back at what crossed each link.

In the ring, every message was a chunk of size S/N. Total bytes per link came to 2S(N-1)/N, which flattens to about 2S no matter how many cores you add. That is bandwidth optimal.

In the tree, every message is the whole tensor S. Core 1 shipped its entire vector to core 0. Core 2 shipped its entire pair-sum to core 0. The root then shipped the entire result down. No chunking ever happens. The links near the root carry full-S messages, and that does not shrink as N grows. The tree’s bandwidth term is about log2(N) times the ring’s: 2 log2(N) back-to-back transfers of S against the ring’s roughly 2S. That is not because any single link carries more bytes (it carries about the same, 2S), but because the full-S transfers lie on a serial dependency chain log2(N) deep in each direction. You stream S through a link log2(N) times back to back per phase, instead of moving S/N chunks on every link at the same time.
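The bytes that sit on the critical path, meaning step count times message size, can be compared directly. The sketch below only evaluates the two expressions from this section for a unit-size tensor; the function names are hypothetical and N is assumed to be a power of two.

```python
def ring_serial_bytes(s, n):
    """2(N-1) steps, each moving an S/N chunk."""
    return 2 * (n - 1) * s / n


def tree_serial_bytes(s, n):
    """2*log2(N) steps, each moving the whole tensor S."""
    return 2 * (n - 1).bit_length() * s


S = 1.0  # unit tensor size, so the results read as multiples of S
for n in (4, 8, 64, 256):
    ring = ring_serial_bytes(S, n)
    tree = tree_serial_bytes(S, n)
    print(f"N={n:>3}  ring={ring:.3f}S  tree={tree:.0f}S  ratio={tree / ring:.2f}")
```

The ring column stays just under 2S while the tree column keeps growing, and the ratio tracks log2(N) scaled by N/(N-1).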

That sets up the real tradeoff. Model the time of one step as a fixed latency alpha plus the message transfer S_msg / B, where B is link bandwidth:

ring time  ~  2(N-1) * (alpha + (S/N)/B)
tree time  ~  2 log2(N) * (alpha +  S/B)

Two regimes fall out.

When S is small, the transfer term is tiny and alpha dominates each step. Cost is basically step count times alpha. Fewer steps wins, so the tree wins. This is the latency-bound regime: small tensors, many cores.

When S is large, the transfer term dominates and the message size matters more than the step count. The ring’s S/N messages beat the tree’s full-S messages, so the ring wins. This is the bandwidth-bound regime: big activation tensors in LLM tensor parallelism, where the ring is the default for good reason.
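Both regimes show up when the two formulas are evaluated at a small and a large tensor size. This is a sketch of the alpha-plus-transfer model only, not a measurement; the constants are illustrative values (1 microsecond per step, 190 GB/s per link, 8 cores), and the function names are assumptions.

```python
ALPHA = 1e-6       # fixed latency per step, seconds (illustrative)
BANDWIDTH = 190e9  # link bandwidth, bytes per second (illustrative)


def ring_time(s, n, alpha=ALPHA, bw=BANDWIDTH):
    """2(N-1) steps, each costing alpha plus an S/N transfer."""
    return 2 * (n - 1) * (alpha + (s / n) / bw)


def tree_time(s, n, alpha=ALPHA, bw=BANDWIDTH):
    """2*log2(N) steps, each costing alpha plus a full-S transfer."""
    levels = (n - 1).bit_length()  # log2(n) for power-of-two n
    return 2 * levels * (alpha + s / bw)


N = 8
for label, size in (("4 KiB", 4 * 1024), ("64 MiB", 64 * 1024 * 1024)):
    r, t = ring_time(size, N), tree_time(size, N)
    winner = "tree" if t < r else "ring"
    print(f"{label:>6}: ring={r * 1e6:9.2f} us  tree={t * 1e6:9.2f} us  -> {winner}")
```

At 4 KiB the transfer term is negligible and the tree's 6 steps beat the ring's 14; at 64 MiB the full-S messages dominate and the ring comes out ahead.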

Plugging in NeuronLink-like numbers, 1 microsecond per hop and 190 GB/s, the two formulas above cross at roughly 360 KB for 8 cores, so a few hundred kilobytes under this simple model. Below that the tree is faster, above it the ring takes over. That crossover, not a universal winner, is why a real collective library carries both algorithms and picks based on tensor size and core count at compile time.
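Setting the ring and tree formulas equal and solving for S gives a closed form for the crossover. The sketch below assumes the simple alpha-plus-transfer model from this section and a power-of-two N; `crossover_bytes` and the other names are hypothetical, and real hardware will move the number.

```python
def ring_time(s, n, alpha, bw):
    return 2 * (n - 1) * (alpha + (s / n) / bw)


def tree_time(s, n, alpha, bw):
    levels = (n - 1).bit_length()  # log2(n) for power-of-two n
    return 2 * levels * (alpha + s / bw)


def crossover_bytes(n, alpha, bw):
    """Tensor size where ring_time == tree_time.

    ring_steps * (alpha + S/(N*bw)) = tree_steps * (alpha + S/bw)
    =>  S = alpha * bw * (ring_steps - tree_steps)
                       / (tree_steps - ring_steps / N)
    """
    ring_steps = 2 * (n - 1)
    tree_steps = 2 * (n - 1).bit_length()
    return alpha * bw * (ring_steps - tree_steps) / (tree_steps - ring_steps / n)


alpha, bw, n = 1e-6, 190e9, 8
s_star = crossover_bytes(n, alpha, bw)
print(f"crossover for N={n}: about {s_star / 1e3:.0f} KB")
print(ring_time(s_star, n, alpha, bw), tree_time(s_star, n, alpha, bw))
```

With these inputs the crossover comes out near 358 KB, in the few-hundred-kilobyte range. The product alpha * bw (190 KB here) sets the scale, so a slower link or a higher per-step latency shifts the crossover proportionally.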

The one sentence to keep

Tree all-reduce is reduce-up plus broadcast-down, each costing log2(N) because the tree halves the live cores every level instead of walking through them one at a time, so the cost is 2 log2(N). It beats the ring’s 2(N-1) on step count for small tensors where latency dominates, and loses on bandwidth for large tensors because every message is the whole tensor instead of a 1/N chunk. Same reduce-then-broadcast skeleton as the ring. Only the routing, and therefore the cost per direction, changes.