The previous blog post talked about Triton linear layout concepts, aiming to provide some underlying motivations and an intuitive understanding. As a companion, in this one I’d like to touch on linear layout internals and follow up with some concrete examples to show its usage in action and make it even more comprehensible. In the same vein, plain language and explanations are preferred over mathematical terms and interpretations.
Basics
Let’s start with a recap of some representation fundamentals and then introduce some key operations which we will need to utilize later.
Data structures
As mentioned in the “How to represent” section of the previous blog post, the
LinearLayout C++ class uses the following data structures to record the index
mapping system:
// bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0). All other values of L are
// computed by xor'ing bases together, using the linearity rule. In addition:
//
// - Each inDim has the same set of outDims, in the same order.
// - The order of dims is minor-to-major, although this only affects reshape.
llvm::MapVector<StringAttr /*inDim*/,
                std::vector<std::vector<int32_t> /*size=getNumOutDims()*/>
                /*size=getInDimSizeLog2(inDim)*/>
    bases;
llvm::MapVector<StringAttr, int32_t /*size*/> outDims;
int32_t rank = 0;
Input and output dimensions
Inside bases, each input dimension has its own map entry containing a readable name and a vector
of output dimension strides.
For bases[inDim], the mapped vector holds the “basis vectors”, which record how each
inDim bit, from least to most significant, contributes strides along all output
dimensions.
Although the input dimension’s name is an unconstrained general string instead of some enumerant
symbol, we conventionally use "register", "lane", "warp", "block" when describing a
hardware compute hierarchy location and "offset" when describing a shared memory allocation.
The Triton codebase has some comments here regarding this and implicitly relies on the
naming convention in various places.
Output dimension names are not recorded in the bases data structure; they are captured in the
separate outDims MapVector shown above.
Recall that these dimensions ordinarily describe the n-D logical tensor; conventionally they are
just named "dimN", like "dim0", "dim1", and so on, which is what the
standardOutDimNames() utility returns.
Primitive index mappings
LinearLayout class provides some static methods for creating basic index mappings:
- LinearLayout::empty(): a 0-D layout that maps everything to 0. In a sense you can think of it as a single “point” in an indexing system. It is typically the initial point for building up more realistic layouts.
- LinearLayout::identity1D(size, inDim, outDim): a 1-D identity layout mapping an index i in inDim to the same i in outDim. Using a similar analogy, this is the linear “line” for an indexing system. We frequently use it as a basic unit when building layouts. For example, to describe that we want to map 4 consecutive elements along "dimN" in the logical tensor to 4 consecutive registers, it would be identity1D(4, "register", "dimN").
- LinearLayout::zeros1D(size, inDim, outDim): a 1-D layout that maps every inDim index i to outDim index 0. This can represent multiple hardware locations mapping to the same logical tensor element; or, put the other way around, broadcasting the same logical tensor element to multiple hardware locations.
- LinearLayout::strided1D(size, stride, inDim, outDim): a 1-D layout that maps every inDim index i to outDim index stride * i.
Note that the above primitive index mappings all require a size.
This is one of the fundamental requirements of linear layout: it is bound to a known static shape,
in contrast to traditional bespoke layouts like blocked/shared/mma layouts, which do not encode
tensor shapes in them.
When converting those bespoke layouts into linear layout, we need to provide the tensor shape
we want to bind to.
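To make these primitives concrete, here is a small sketch of what constructing them can look like in C++; it assumes the Triton LinearLayout header and an MLIRContext (paths and exact signatures may differ across Triton versions), and the comments note the 1-D mapping each constructor encodes:
#include "triton/Tools/LinearLayout.h"

using namespace mlir;
using namespace mlir::triton;

void primitiveExamples(MLIRContext *ctx) {
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kDim0 = StringAttr::get(ctx, "dim0");

  // L(i) = i for i in [0, 8): 8 lanes map to 8 consecutive elements.
  LinearLayout id = LinearLayout::identity1D(8, kLane, kDim0);
  // L(i) = 0 for i in [0, 8): 8 lanes all map to (broadcast) element 0.
  LinearLayout zeros = LinearLayout::zeros1D(8, kLane, kDim0);
  // L(i) = 4 * i for i in [0, 8): each lane strides by 4 elements.
  LinearLayout strided = LinearLayout::strided1D(8, /*stride=*/4, kLane, kDim0);
}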
Operations
With the above basics and 1-D primitives, we can now start building useful n-D layouts. This can be achieved with a few key linear layout operations that generalize and scale well and are quite powerful.
Product
The LinearLayout::operator*() method implements products.
This method is frequently used for building real-world layouts where we map whole hardware
hierarchy to n-D logical tensor elements.
The code documentation, which contains some examples, is well worth reading.
An intuitive interpretation of product is that we are building a larger “space” from smaller
“subspaces” as “primitives.”
For example, let’s say we want to describe a mapping that, along "dimN" of the tensor, we have
a consecutive 16 GPU threads and each thread owns 4 consecutive elements in its
registers.
We can define laneLL = LinearLayout::identity1D(16, "lane", "dimN") for the first part,
and registerLL = LinearLayout::identity1D(4, "register", "dimN") for the second part and do
registerLL * laneLL to describe the overall mapping.
It treats every 4 register elements as a subspace in the larger space;
if we “collapse” that subspace as a single point, it’s clear that the "lane" to "dimN" index
mapping is identity which reflects the laneLL construct.
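As a sketch (same assumptions as the snippet above), the construction and the resulting mapping look roughly like this:
// 16 consecutive lanes along "dimN", each owning 4 consecutive elements in
// registers. "register" comes first in the product, so it is the minor
// (fastest-varying) input dimension.
LinearLayout buildWarpRow(MLIRContext *ctx) {
  StringAttr kRegister = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kDimN = StringAttr::get(ctx, "dimN");

  LinearLayout registerLL = LinearLayout::identity1D(4, kRegister, kDimN);
  LinearLayout laneLL = LinearLayout::identity1D(16, kLane, kDimN);

  // The combined map is dimN = register + 4 * lane: the lane bases pick up
  // stride 4 because the register subspace already covers 4 elements.
  return registerLL * laneLL;
}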
Note that the multiplication order matters! Using examples from the documentation,
identity1D(4, "i", "o") * zeros1D(2, "i", "o") gives a layout of L(x) = x % 4.
Similarly using the above interpretation, after collapsing a subspace of 4 consecutive elements,
all indices map to 0; therefore it’s L(x) = x % 4.
If we do zeros1D(2, "i", "o") * identity1D(4, "i", "o"), then the primitive is 2
consecutive input indices mapping to 0, and we use that as a subspace to build consecutive
linear indexing on top; therefore it is L(x) = x / 2.
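Here is a small sketch of the two orderings (same assumptions as before), with the resulting bases noted in comments:
void productOrderMatters(MLIRContext *ctx) {
  StringAttr kI = StringAttr::get(ctx, "i");
  StringAttr kO = StringAttr::get(ctx, "o");

  // Bases: i=1 -> 1, i=2 -> 2, i=4 -> 0, i.e. L(x) = x % 4.
  LinearLayout modFour =
      LinearLayout::identity1D(4, kI, kO) * LinearLayout::zeros1D(2, kI, kO);

  // Bases: i=1 -> 0, i=2 -> 1, i=4 -> 2, i.e. L(x) = x / 2.
  LinearLayout divTwo =
      LinearLayout::zeros1D(2, kI, kO) * LinearLayout::identity1D(4, kI, kO);
}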
Composition
The LinearLayout::compose() method implements composition.
Composition is like nested function calls: it takes the inner linear layout’s output dimensions and
feeds them as input dimensions to the outer linear layout.
Composition on its own is not particularly interesting in practice, because we normally describe layouts from hardware location to logical tensor elements, so the logical tensor typically serves as the output dimensions. It becomes more interesting together with inversion.
Inversion
The LinearLayout::(pseudo)invert() methods implement inversion.
Intuitively, as the name suggests, inversion reverses the mapping so that we go from output to input dimensions.
There is a set of restrictions on whether/how a layout can be inverted; for example, every index
along the output dimension needs at least one input index mapping to it so that we cover the
whole logical tensor.
I’ll omit the math details; you can read the official linear layout paper if you are
interested in them.
With inversion, we can use logical tensor dimensions as the “bridge” to connect different
hardware locations.
Such a mapping is very useful for understanding and “computing” characteristics when data flows
among different hardware locations.
For that, we have a dedicated API: the LinearLayout::invertAndCompose() method.
One common such usage is figuring out the element offsets in the shared memory that each GPU
register should read.
We can simply use a distributedLL, which describes how compute hierarchy wants tensor elements,
and a sharedLL, which describes how shared memory allocation holds tensor elements, and perform
distributedLL.invertAndCompose(sharedLL), which will give us the mapping from registers to shared
memory allocation offsets.
(Note that A.invertAndCompose(B) means B^-1(A).)
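As a hedged sketch (the helper name regToSharedOffsets is made up, and it assumes a toLinearLayout()-style utility that converts an encoding attribute plus shape into a LinearLayout; exact signatures vary across Triton versions):
// Maps (register, lane, warp, block) to the shared memory "offset" for a
// tensor distributed across registers with distEnc and stored in shared
// memory with sharedEnc.
LinearLayout regToSharedOffsets(ArrayRef<int64_t> shape, Attribute distEnc,
                                Attribute sharedEnc) {
  // How the compute hierarchy wants tensor elements.
  LinearLayout distributedLL = triton::gpu::toLinearLayout(shape, distEnc);
  // How the shared memory allocation holds tensor elements.
  LinearLayout sharedLL = triton::gpu::toLinearLayout(shape, sharedEnc);
  // sharedLL^-1(distributedLL(...)): hardware location -> shared memory offset.
  return distributedLL.invertAndCompose(sharedLL);
}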
Broadly, for GPU performance, how we move data across the hardware hierarchy matters a lot.
Questions like whether we are reading/writing global memory with wide instructions in a coalesced
manner, or whether we can exchange data simply via register or thread exchange within the same
warp versus going across warps through shared memory, and so on, can all be answered fairly
easily by deducing from A.invertAndCompose(B).
But before we go into that, I’d like to explain another pretty useful operation.
Left division
The LinearLayout::divideLeft() method implements left division.
Mathematically it’s involved, but intuitively you can interpret it as the reverse of product
(reflected in the names: division vs. multiplication).
If product builds larger spaces from smaller subspaces, divideLeft(A, B) essentially
asks whether B is a subspace of A, and if so we can obtain a resultant layout by
“collapsing” B.
This operation can be pretty useful to check, for example, whether a data movement can be
implemented using a particular fast hardware intrinsic, as we will see later in examples.
Examples
Okay, we have introduced the major building blocks; now we can dive into concrete examples to see linear layouts in action.
N-D identity layout
The identityStandardND() utility is a good starting point, combining some of the pieces
discussed above in a straightforward manner:
// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
// permute(shape, order).
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
                                ArrayRef<unsigned> order) {
  assert(shape.size() == order.size());
  MLIRContext *ctx = inDimName.getContext();
  auto rank = shape.size();
  // The order in triton is written wrt. [dim0, dim1, ...].
  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
  LinearLayout ret = LinearLayout::empty();
  for (int i = 0; i < shape.size(); i++) {
    // Start with the most-minor dimension, which is order[0].
    int dim = order[i];
    ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
  }
  return ret;
}
From the definition, we build the final layout by iteratively multiplying what we have so far
with a 1-D identity mapping from the input dimension to the next output dimension.
We can understand this intuitively with the earlier subspace interpretation.
Or, another way to think about it: we fold/merge the previous innermost consecutive levels into one
element so that the next level becomes consecutive, and therefore we can multiply with the next
level’s identity1D.
Dumping an example with llvm::outs() << identityStandardND("register", {2, 4, 8}, {2, 1, 0}):
- register=1 -> (1, 0, 0)
register=2 -> (2, 0, 0)
register=4 -> (4, 0, 0)
register=8 -> (0, 1, 0)
register=16 -> (0, 2, 0)
register=32 -> (0, 0, 1)
where out dims are: [dim2 (size 8), dim1 (size 4), dim0 (size 2)]
Note that, as previously explained, LinearLayout internally records an index mapping value for
each bit of the input dimension and the corresponding strides on all output dimensions.
So, printed in decimal format, all values are powers of two.
From the above print we can see that as we consecutively increase the index along the "register"
input dimension, we first advance consecutively along "dim2", then "dim1", then "dim0",
as we would expect for an N-D identity mapping with the given innermost-to-outermost order.
MFMA layout
For the MFMA layout mentioned in the previous blog post, we would like a 2-D layout,
where we first map 4 consecutive elements along M ("dim0") output dimension to consecutive
registers, and then we map 16 consecutive elements along N ("dim1") output dimension to
consecutive lanes.
Folding/merging the above into one “element”, we can see that at the next level we map 4
consecutive such elements along the M output dimension to the remaining (64 / 16) lane groups.
Therefore, what we need would be
LinearLayout::identity1D(4, "register", dimM) *
LinearLayout::identity1D(16, "lane", dimN) *
LinearLayout::identity1D(64 / 16, "lane", dimM)
The above is a simplified version of what the AMDMfmaEncodingAttr::toLinearLayout()
method effectively does.
The full implementation is more complicated given it needs to support different MFMA intrinsics
and handle multiple warp layouts, while the above only talks about a single warp tile.
Multiple warps and even multiple blocks simply nest more levels on top of the above without
disturbing what’s described, so hopefully it is easy to understand. The sketch below illustrates the warp-level nesting.
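As a rough sketch (this is not the actual AMDMfmaEncodingAttr::toLinearLayout() implementation, and the helper name is made up), a single 16x16 MFMA warp tile extended with a [2, 2] warp grid would look like:
// dimM/dimN stand for the "dim0"/"dim1" StringAttrs of the logical tensor.
LinearLayout mfmaTileWithWarps(StringAttr kRegister, StringAttr kLane,
                               StringAttr kWarp, StringAttr dimM,
                               StringAttr dimN) {
  // Single 16x16 warp tile: 4 registers along M, 16 lanes along N, and the
  // remaining 64 / 16 lane groups along M.
  LinearLayout warpTile = LinearLayout::identity1D(4, kRegister, dimM) *
                          LinearLayout::identity1D(16, kLane, dimN) *
                          LinearLayout::identity1D(64 / 16, kLane, dimM);
  // warpsPerCTA = [2, 2]: warps advance along N first, then M; the product
  // automatically picks up stride 16 because the warp tile is 16x16.
  return warpTile * LinearLayout::identity1D(2, kWarp, dimN) *
         LinearLayout::identity1D(2, kWarp, dimM);
}
Covering a larger tensor such as 32x64 then simply repeats this CTA tile with extra register bases, which the real implementation handles; that is where the register=4 -> (0, 32) basis in the dump below comes from.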
For #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 2], instrShape = [16, 16, 16], isTransposed = false}>,
dumping its layout on a tensor<32x64xf32>:
- register=1 -> (1, 0)
register=2 -> (2, 0)
register=4 -> (0, 32)
- lane=1 -> (0, 1)
lane=2 -> (0, 2)
lane=4 -> (0, 4)
lane=8 -> (0, 8)
lane=16 -> (4, 0)
lane=32 -> (8, 0)
- warp=1 -> (0, 16)
warp=2 -> (16, 0)
- block is a size 1 dimension
where out dims are: [dim0 (size 32), dim1 (size 64)]
I won’t repeat all the details above; just pointing out some key points:
as expected, the first 16 threads hold consecutive elements along N ("dim1"), and the next
16 threads start at index 4 along M ("dim0").
And if we fold/merge every 16x16 block of elements (which is what a warp tile is) into one, we can
see that warps increase consecutively along N and then M, which composes the next level of
nesting.
There are 2 warps along each dimension, matching warpsPerCTA = [2, 2].
Distributed to shared layout
In the Inversion section we briefly touched on using
cvt = distributedLL.invertAndCompose(sharedLL) to get the mapping from registers to shared memory
allocation offsets and leveraging its characteristics to optimize lowering.
One cool example is figuring out the largest vectorization we can perform when lowering shared
memory loads/stores, which is implemented as the largestVectorization() utility.
The idea is that we start from the largest possible bitwidth allowed by hardware and test whether
it’s possible.
If not, we halve the bitwidth and try again, looping until we find one that works.
How do we perform the check? If we think of one individual hardware instruction as the primitive,
described as tile = LinearLayout::identity1D(vectorSize, "register", "offset"),
then the check is whether we can break the whole conversion cvt into such primitive tiles; that is,
whether we can perfectly “divide” cvt by tile.
That comes back to what we explained earlier: divideLeft can perform exactly this division!
With divideLeft(cvt, tile), we can implement largestVectorization() in a pretty clean and simple
manner.
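A simplified sketch of that idea (this is not the actual largestVectorization() code, which also accounts for element bitwidth and other constraints; it assumes divideLeft(cvt, tile) yields something that converts to true on success):
// Returns the largest number of elements per instruction for which the
// register -> shared memory offset conversion cvt decomposes into
// identical vectorized tiles.
int largestVecElems(MLIRContext *ctx, const LinearLayout &cvt,
                    int maxVecElems) {
  StringAttr kRegister = StringAttr::get(ctx, "register");
  StringAttr kOffset = StringAttr::get(ctx, "offset");
  for (int v = maxVecElems; v > 1; v /= 2) {
    // One vectorized instruction: v consecutive registers hitting v
    // consecutive shared memory offsets.
    LinearLayout tile = LinearLayout::identity1D(v, kRegister, kOffset);
    // If tile left-divides cvt, the whole conversion is a repetition of tile.
    if (divideLeft(cvt, tile))
      return v;
  }
  return 1;
}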
Generic layout conversion
Moving on to the next example.
ttg.convert_layout is one of the key operations in the Triton GPU dialect; it bridges one layout
to another during compiler conversion flows.
Such layout conversions potentially require data movement across the hardware hierarchy so they can be
expensive.
In Triton we run the TritonGPURemoveLayoutConversionsPass to optimize them away as much
as possible (and linear layout helps prove that some layout conversions are trivial), though there
are still cases where we cannot completely eliminate one and need to materialize a conversion.
For such cases, we want to optimize the conversion—for example, if data exchange is among threads in the same warp then we can perform warp shuffling which would be much cheaper than going through shared memory. If data exchange is among different warps then we are forced to use shared memory.
To check which path to go down when converting from srcLayout to dstLayout, all we need
is cvt = dstLayout.invertAndCompose(srcLayout) (so that we get a mapping from
destination to source layout indices), and then to check whether cvt’s "warp" dimension is just an
identity onto itself and involves no other dimensions at all.
If so, there is no cross-warp data exchange; otherwise there is.
Similarly we can check other dimensions like "block" and "lane" to deduce whether data exchange
crosses them or stays within them, which is quite elegant.
This general approach is concretely implemented in the ttg.convert_layout’s LLVM conversion
pattern, where we first get minimalCvtLayout(srcTy, dstTy), and then go down
the list to check whether it involves "block", "warp", and "lane" level data exchange.
The minimalCvtLayout() utility is commonly used for deducing the
“smallest submap of srcTy^{-1} * dstTy that is not the identity under the common dimensions.”
Internally, it just calls dstLayout.invertAndCompose(srcLayout) and then applies the quotient()
operation to drop identity dimensions from slowest to fastest.
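A rough paraphrase of that dispatch (not the exact lowering code; it assumes minimalCvtLayout() and the getInDimNames() accessor on LinearLayout, whose signatures may differ across versions):
void lowerConvertLayout(MLIRContext *ctx, RankedTensorType srcTy,
                        RankedTensorType dstTy) {
  // Identity dimensions are already dropped, so whichever input dims remain
  // tell us the widest level of the hierarchy the data exchange touches.
  LinearLayout cvt = minimalCvtLayout(srcTy, dstTy);
  auto dims = llvm::to_vector(cvt.getInDimNames());

  if (llvm::is_contained(dims, StringAttr::get(ctx, "block"))) {
    // Cross-block data exchange.
  } else if (llvm::is_contained(dims, StringAttr::get(ctx, "warp"))) {
    // Cross-warp exchange: round-trip through shared memory.
  } else if (llvm::is_contained(dims, StringAttr::get(ctx, "lane"))) {
    // Within-warp exchange: warp shuffles suffice.
  } else {
    // Pure register permutation within each thread.
  }
}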
Final Words
This blog post supplements the previous one and the linear layout paper with discussions of linear layout representation internals and operations. We then went over some concrete examples to see how linear layouts are built up and utilized for optimizing code generation in action. Linear layout now replaces various bespoke conversions we had in Triton; once you fully grasp the ideas, it provides great simplicity and generality. Hopefully this blog post helps along that line.