Tiny CUTLASS CuTe: from tiling to layout algebra, a hacker's delight

The Loser Master

A Simple Single-level tiling

It starts with a simple copy.

Let's say we copy from src to dst: a simple full copy without any slicing.


// src: [0, 1, 2, 3]
// dst: [0, 1, 2, 3]
for (int i = 0; i < dst.size(); i++) {
  dst[i] = src[i];
}
          

A more complicated 2D full copy looks like this.


// shape = (2, 4)
// src: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
//
// shape = (2, 4)
// dst: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
for (int i = 0; i < dst.size(); i++) {
  dst[i] = src[i];
}
          

Or we can copy in a coordinate manner, looping over rows and columns.


// shape = (2, 4)
// src: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
//
// shape = (2, 4)
// dst: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
for (int r = 0; r < dst.size<0>(); r++) {
  for (int c = 0; c < dst.size<1>(); c++) {
    dst[r][c] = src[r][c];
  }
}
          
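The flat loop and the coordinate loop visit the same elements because, for a contiguous row-major (2, 4) buffer, a flat index i and a coordinate (r, c) are tied by i = r * 4 + c. A minimal sketch of that round trip, assuming row-major order (last mode varies fastest); idx2crd_rm, crd2idx_rm and copy_by_coord are hypothetical helper names:

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Coord2 = std::array<std::size_t, 2>;

// Hypothetical helper: flat index -> (row, col) for a row-major (rows, cols) shape.
Coord2 idx2crd_rm(std::size_t i, Coord2 shape) {
  assert(i < shape[0] * shape[1]);
  return {i / shape[1], i % shape[1]};
}

// Hypothetical helper: (row, col) -> flat index for the same row-major shape.
std::size_t crd2idx_rm(Coord2 crd, Coord2 shape) {
  assert(crd[0] < shape[0] && crd[1] < shape[1]);
  return crd[0] * shape[1] + crd[1];
}

// Copy coordinate by coordinate; touches exactly the elements the flat loop touches.
void copy_by_coord(const int* src, int* dst, Coord2 shape) {
  for (std::size_t r = 0; r < shape[0]; r++) {
    for (std::size_t c = 0; c < shape[1]; c++) {
      std::size_t i = crd2idx_rm({r, c}, shape);
      dst[i] = src[i];
    }
  }
}
```

For shape (2, 4), idx2crd_rm(5, shape) gives (1, 1), and crd2idx_rm maps it back to 5.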

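Things become more complicated when we want to copy a slice. Now the stride of the underlying physical layout comes into play:
1. We want to copy a slice of shape (2, 2).
2. The physical layout has shape (2, 4), so its row-major stride is (4, 1).
3. The loop must therefore walk the slice's shape but compute indices with the physical stride.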


// shape = (2, 4)
// src: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
//
// shape = (2, 4)
// dst: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
//
// NOTE: we want to copy a slice:
// [
//   [0, 1]
//   [4, 5]
// ]
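int phy_shape[2] = {2, 4};   // physical (buffer) shape
int phy_stride[2] = {4, 1};  // row-major stride of the (2, 4) buffer
int slice_shape[2] = {2, 2}; // logical shape of the slice we copy
// loop over the slice's shape, not the full buffer's shape
for (int r = 0; r < slice_shape[0]; r++) {
  for (int c = 0; c < slice_shape[1]; c++) {
    // coordinate to index: idx = crd2idx(r, c)
    int i = r * phy_stride[0] + c * phy_stride[1];
    dst[i] = src[i];
  }
}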
          
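The corrected loop can be packaged as one small function. This is a sketch under stated assumptions: copy_slice is a hypothetical name, both buffers share the same (2, 4) row-major physical layout with stride (4, 1), and a base offset parameter is added so the same function also covers slices that do not start at element 0.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Shape2 = std::array<std::size_t, 2>;
using Stride2 = std::array<std::size_t, 2>;

// Hypothetical helper: copy a slice of shape `slice_shape` whose first element
// sits at physical index `base_offset`. Both buffers hold `buf_size` elements
// and share the same physical stride, so one index serves both sides.
void copy_slice(const int* src, int* dst, std::size_t buf_size,
                std::size_t base_offset, Shape2 slice_shape, Stride2 phy_stride) {
  for (std::size_t r = 0; r < slice_shape[0]; r++) {
    for (std::size_t c = 0; c < slice_shape[1]; c++) {
      // coordinate to physical index: base_offset + r * stride0 + c * stride1
      std::size_t i = base_offset + r * phy_stride[0] + c * phy_stride[1];
      assert(i < buf_size);  // the slice must not run past the buffer
      dst[i] = src[i];
    }
  }
}
```

With base_offset = 0, shape (2, 2) and stride (4, 1) it copies elements 0, 1, 4, 5; with base_offset = 2 it copies 2, 3, 6, 7.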

What if the slice starts at an offset? Say we want [2, 3, 6, 7].
The stride still works; only the starting point moves.


// shape = (2, 4)
// src: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
//
// shape = (2, 4)
// dst: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
//
// NOTE: we want to copy a slice:
// [
//   [2, 3]
//   [6, 7]
// ]
//
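// NOTE: The stride is still 4!
// [
//   [0, 1]        <- skipped by the base offset of 2
//   [2, 3, x, x]  <- x: elements 4, 5, skipped by the stride
//   [6, 7, ?, ?]  <- ?: past the end of the buffer, never touched
// ]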
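int phy_shape[2] = {2, 4};   // physical (buffer) shape
int phy_stride[2] = {4, 1};  // row-major stride of the (2, 4) buffer
int slice_shape[2] = {2, 2}; // logical shape of the slice we copy
// dst and src are raw pointers to the (2, 4) buffers:
// shift both by the slice's offset, crd2idx(0, 2) = 2
dst = dst + 2;
src = src + 2;
for (int r = 0; r < slice_shape[0]; r++) {
  for (int c = 0; c < slice_shape[1]; c++) {
    // coordinate to index: idx = crd2idx(r, c)
    int i = r * phy_stride[0] + c * phy_stride[1];
    dst[i] = src[i];
  }
}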
          
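The arithmetic behind the pointer shift: the offset b is the physical index of the slice's first element, coordinate (0, 2) in the (2, 4):(4, 1) layout, and each slice coordinate (r, c) then adds the unchanged stride on top of it:

```latex
\begin{aligned}
b &= \mathrm{crd2idx}(0, 2) = 0 \cdot 4 + 2 \cdot 1 = 2,\\
\mathrm{idx}(r, c) &= b + 4r + 1 \cdot c,\\
\mathrm{idx}(0, 0) &= 2,\quad \mathrm{idx}(0, 1) = 3,\quad \mathrm{idx}(1, 0) = 6,\quad \mathrm{idx}(1, 1) = 7.
\end{aligned}
```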

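Tiling is essentially creating a slice of global memory or shared memory and preparing the data by copying it, e.g.:
- g2g: global memory to global memory
- g2s: global memory to shared memory
- s2r: shared memory to registers
- ...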
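Picking a tile is the same offset trick as above: the tile's base offset is the physical index of its top-left element. A minimal sketch, assuming a matrix with a known stride cut evenly into tiles of shape (TM, TN); tile_base_offset and tile_count are hypothetical helpers, not CuTe APIs.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Pair = std::array<std::size_t, 2>;

// Hypothetical helper: physical index of the first element of tile
// `tile_crd` = (tile_m, tile_n) when a matrix with stride `stride`
// is cut into tiles of shape `tile_shape` = (TM, TN).
std::size_t tile_base_offset(Pair tile_crd, Pair tile_shape, Pair stride) {
  // the tile's top-left element sits at (tile_m * TM, tile_n * TN)
  std::size_t row = tile_crd[0] * tile_shape[0];
  std::size_t col = tile_crd[1] * tile_shape[1];
  return row * stride[0] + col * stride[1];
}

// Number of tiles along each mode, assuming the shape divides evenly.
Pair tile_count(Pair shape, Pair tile_shape) {
  assert(shape[0] % tile_shape[0] == 0 && shape[1] % tile_shape[1] == 0);
  return {shape[0] / tile_shape[0], shape[1] / tile_shape[1]};
}
```

For the (2, 4):(4, 1) example cut into (2, 2) tiles there are (1, 2) tiles, and tile (0, 1) starts at offset 2: the [2, 3, 6, 7] slice.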

The Multi-level tiling

Suppose we have a slice of global memory, gAgA, and a copy of this slice in shared memory, sAsA.

In gAgA, the first gA names the "logical layout": the global logical layout of A. The second gA names the "physical layout and backend": the physical layout of A, backed by global memory. sAsA reads the same way for shared memory.

The plan is:
1. Copy a slice from global memory to shared memory.
2. Copy a slice per warp group from shared memory, following the warps layout.
3. Copy from a slice of the warp's layout into each thread's registers.

To make things easy, we introduce Tensor, a structure that takes care of the shape, the stride, crd2idx (logical coordinate to physical index), idx2crd (logical index to logical coordinate), and the base offset (added to every physical index).


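#include <array>
#include <cstddef>

// 2D sketch: shapes, strides and coordinates are pairs of integers.
using Shape  = std::array<std::size_t, 2>;
using Stride = std::array<std::size_t, 2>;
using Coord  = std::array<std::size_t, 2>;

struct Tensor {
  void* data_ptr;
  std::size_t base_offset_;
  Shape shape_;
  Stride stride_;
  // logical coordinate to **physical** index (includes base_offset_).
  std::size_t crd2idx(Coord crd) const;
  // logical index to **logical** coordinate.
  Coord idx2crd(std::size_t idx) const;
  // number of logical elements: shape_[0] * shape_[1].
  std::size_t size() const;
};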
          
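For the 2D case the two mappings can be written out explicitly. Assuming row-major logical ordering (the last mode varies fastest, matching the loops above; CuTe itself uses colexicographic order, where the first mode varies fastest), with shape (S_0, S_1), stride (d_0, d_1) and base offset b:

```latex
\begin{aligned}
\mathrm{idx2crd}(i) &= \left(\lfloor i / S_1 \rfloor,\; i \bmod S_1\right),\\
\mathrm{crd2idx}(c_0, c_1) &= b + c_0 d_0 + c_1 d_1.
\end{aligned}
```

For a (2, 2):(4, 1) view with b = 2 (the [2, 3, 6, 7] slice of the (2, 4) buffer), logical indices 0, 1, 2, 3 become coordinates (0, 0), (0, 1), (1, 0), (1, 1) and then physical indices 2, 3, 6, 7. For a (2, 2):(2, 1) view with b = 0, such as a compact shared-memory copy, the same coordinates land on 0, 1, 2, 3: same logical layout, different physical layouts.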

Then the copy looks like this.


// shape = (2, 4)
// global_layout: [
//   [0, 1, 2, 3]
//   [4, 5, 6, 7]
// ]
//
// shape = (2, 2)
// global_local_slice: [
//   [x, x, 2, 3]
//   [x, x, 6, 7]
// ]
//
// shape = (2, 2)
// smem_layout: [
//   [2, 3]
//   [6, 7]
// ]
//
// shape = (1, 2)
// warps_layout: [
//   [2, 3]
// ]
//
// shape = (1, 1)
// threads_layout: [
//   [3]
// ]

// name = gmem layout of A, backed by gmem
// shape:stride = (2, 4):(4, 1)
// base offset = 0
Tensor gA;
// name = local gmem layout of A, backed by gmem
// shape:stride = (2, 2):(4, 1)
// base offset = xxx
Tensor gAgA;
// name = smem layout of A, backed by smem
// shape:stride = (2, 2):(2, 1)
// base offset = 0
Tensor sAsA;
// name = warps layout of A, backed by smem
// shape:stride = (1, 2):(2, 1)
// base offset = xxx
Tensor wAsA;
// name = threads layout of A, backed by reg
// shape:stride = (1, 1):(1, 1)
// base offset = xxx
Tensor tArA;

// global to shared memory
// (each stage gets its own scope so src/dst can be redeclared)
{
  auto src = gAgA;
  auto dst = sAsA;
  for (std::size_t i = 0; i < dst.size(); i++) {
    // get logical coordinate
    auto l_coord_src = src.idx2crd(i);
    auto l_coord_dst = dst.idx2crd(i);
    // get physical index
    auto src_idx = src.crd2idx(l_coord_src);
    auto dst_idx = dst.crd2idx(l_coord_dst);
    dst[dst_idx] = src[src_idx];
  }
}

// shared memory to warp tile
{
  auto src = sAsA;
  auto dst = wAsA;
  for (std::size_t i = 0; i < dst.size(); i++) {
    // get logical coordinate
    auto l_coord_src = src.idx2crd(i);
    auto l_coord_dst = dst.idx2crd(i);
    // get physical index
    auto src_idx = src.crd2idx(l_coord_src);
    auto dst_idx = dst.crd2idx(l_coord_dst);
    dst[dst_idx] = src[src_idx];
  }
}

// warp tile to thread's reg
{
  auto src = wAsA;
  auto dst = tArA;
  for (std::size_t i = 0; i < dst.size(); i++) {
    // get logical coordinate
    auto l_coord_src = src.idx2crd(i);
    auto l_coord_dst = dst.idx2crd(i);
    // get physical index
    auto src_idx = src.crd2idx(l_coord_src);
    auto dst_idx = dst.crd2idx(l_coord_dst);
    dst[dst_idx] = src[src_idx];
  }
}
          
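To check that the stages compose, here is a self-contained sketch. It is one possible implementation under stated assumptions, not CuTe's: View, slice, copy_view and run_example are hypothetical names, coordinates are ordered row-major, the warp tile is treated as a view of shared memory (it is backed by smem), and the thread reads its own (1, 1) slice of that tile (called tAsA here) so the register ends up holding 3 as in the threads_layout diagram.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Pair = std::array<std::size_t, 2>;

// Hypothetical minimal 2D tensor view: it does not own its storage,
// it only maps logical coordinates onto a buffer.
struct View {
  int* data;
  std::size_t base_offset;
  Pair shape;
  Pair stride;

  std::size_t size() const { return shape[0] * shape[1]; }
  // logical index -> logical coordinate (row-major: last mode fastest)
  Pair idx2crd(std::size_t i) const { return {i / shape[1], i % shape[1]}; }
  // logical coordinate -> physical index, including the base offset
  std::size_t crd2idx(Pair c) const {
    assert(c[0] < shape[0] && c[1] < shape[1]);
    return base_offset + c[0] * stride[0] + c[1] * stride[1];
  }
  // element access by physical index
  int& operator[](std::size_t phys) const { return data[phys]; }
  // sub-view: same storage and stride, new shape, origin moved to `origin`
  View slice(Pair origin, Pair new_shape) const {
    return View{data, crd2idx(origin), new_shape, stride};
  }
};

// The copy loop from the text, written once for any pair of equal-size views.
void copy_view(const View& src, const View& dst) {
  assert(src.size() == dst.size());
  for (std::size_t i = 0; i < dst.size(); i++) {
    dst[dst.crd2idx(dst.idx2crd(i))] = src[src.crd2idx(src.idx2crd(i))];
  }
}

// Walks the levels on a (2, 4) global buffer and returns the register value.
int run_example(int* gmem, int* smem, int* reg) {
  View gA{gmem, 0, {2, 4}, {4, 1}};       // (2, 4):(4, 1), offset 0
  View gAgA = gA.slice({0, 2}, {2, 2});   // (2, 2):(4, 1), offset 2
  View sAsA{smem, 0, {2, 2}, {2, 1}};     // (2, 2):(2, 1), offset 0
  copy_view(gAgA, sAsA);                  // smem = {2, 3, 6, 7}
  View wAsA = sAsA.slice({0, 0}, {1, 2}); // warp tile: {2, 3}
  View tAsA = wAsA.slice({0, 1}, {1, 1}); // this thread's element: {3}
  View tArA{reg, 0, {1, 1}, {1, 1}};      // (1, 1):(1, 1) in registers
  copy_view(tAsA, tArA);
  return reg[0];
}
```

Note how the base offset of gAgA falls out of gA.crd2idx({0, 2}) instead of being written by hand.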

Now I am happy to get rid of hand-written index arithmetic like: ptr + base_offset + stride0 * a + stride1 * b + ...

The physical, the logical, and the hierarchical

To be continued.

Hacker's delight: Let's play with bits

To be continued.

The Layout Algebra

To be continued.