It starts with a simple copy.
Let's say we copy from src to dst: a simple full copy without any slicing.
// src: [0, 1, 2, 3]
// dst: [0, 1, 2, 3]
for (int i = 0; i < dst.size(); i++) {
    dst[i] = src[i];
}
A more complicated 2D full copy looks like this.
// shape = (2, 4)
// src: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
//
// shape = (2, 4)
// dst: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
for (int i = 0; i < dst.size(); i++) {
    dst[i] = src[i];
}
Or we can iterate by coordinate.
// shape = (2, 4)
// src: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
//
// shape = (2, 4)
// dst: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
for (int r = 0; r < dst.size<0>(); r++) {
    for (int c = 0; c < dst.size<1>(); c++) {
        dst[r][c] = src[r][c];
    }
}
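To see that the flat loop and the coordinate loop touch the same elements, here is a minimal sketch, assuming a row-major (2, 4) buffer stored in a std::array. The names Buffer, copy_flat, and copy_by_coord are made up for this illustration.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kRows = 2;
constexpr std::size_t kCols = 4;
using Buffer = std::array<int, kRows * kCols>;  // assumed row-major storage

// Flat copy: walk the physical indices 0..size-1.
Buffer copy_flat(const Buffer& src) {
    Buffer dst{};
    for (std::size_t i = 0; i < dst.size(); i++) {
        dst[i] = src[i];
    }
    return dst;
}

// Coordinate copy: walk (r, c) and map each coordinate to a physical index.
Buffer copy_by_coord(const Buffer& src) {
    Buffer dst{};
    for (std::size_t r = 0; r < kRows; r++) {
        for (std::size_t c = 0; c < kCols; c++) {
            std::size_t i = r * kCols + c;  // row-major: stride = (kCols, 1)
            assert(i < dst.size());
            dst[i] = src[i];
        }
    }
    return dst;
}
```

For a full copy both functions produce the same result. The coordinate form only starts to pay off once we copy less than the whole buffer.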
Things become more complicated if we want to copy a slice. This is where the stride of the underlying physical layout comes in.
1. We want to copy a slice of shape (2, 2).
2. The physical layout has shape (2, 4).
3. So we loop over the slice shape but index with the physical stride.
// shape = (2, 4)
// src: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
//
// shape = (2, 4)
// dst: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
//
// NOTE: we want to copy a slice:
// [
// [0, 1]
// [4, 5]
// ]
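int phy_shape[]   = {2, 4};  // shape of the underlying buffer
int phy_stride[]  = {4, 1};  // row-major stride of the underlying buffer
int slice_shape[] = {2, 2};  // shape of the slice we want to copy
for (int r = 0; r < slice_shape[0]; r++) {
    for (int c = 0; c < slice_shape[1]; c++) {
        // coordinate to index: idx = crd2idx(r, c)
        auto i = r * phy_stride[0] + c * phy_stride[1];
        dst[i] = src[i];
    }
}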
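The slice copy above can be checked with a small runnable sketch. It assumes the same row-major (2, 4) buffer held in a std::array; Buffer and copy_slice are hypothetical names used only here.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Buffer = std::array<int, 8>;  // physical (2, 4) buffer, assumed row-major

// Copy a (slice_rows, slice_cols) slice anchored at the top-left corner.
// The loop bounds come from the slice shape, the index from the physical stride.
void copy_slice(const Buffer& src, Buffer& dst,
                std::size_t slice_rows, std::size_t slice_cols) {
    const std::size_t phy_shape[2] = {2, 4};
    const std::size_t phy_stride[2] = {4, 1};
    assert(slice_rows <= phy_shape[0] && slice_cols <= phy_shape[1]);
    for (std::size_t r = 0; r < slice_rows; r++) {
        for (std::size_t c = 0; c < slice_cols; c++) {
            std::size_t i = r * phy_stride[0] + c * phy_stride[1];
            dst[i] = src[i];
        }
    }
}
```

With a (2, 2) slice only the physical indices 0, 1, 4, and 5 are written. The rest of dst is left untouched.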
What if the slice starts at an offset? Let's say we want [2, 3, 6, 7].
The stride still works.
// shape = (2, 4)
// src: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
//
// shape = (2, 4)
// dst: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
//
// NOTE: we want to copy a slice:
// [
// [2, 3]
// [6, 7]
// ]
//
// NOTE: after moving the base by the offset, the row stride is still 4!
// [
// [offset]
// [2, 3, x, x]
// [6, 7, ?, ?]
// ]
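int phy_shape[]   = {2, 4};  // shape of the underlying buffer
int phy_stride[]  = {4, 1};  // row-major stride of the underlying buffer
int slice_shape[] = {2, 2};  // shape of the slice we want to copy
// move both base pointers to the first element of the slice
dst = dst + 2;
src = src + 2;
for (int r = 0; r < slice_shape[0]; r++) {
    for (int c = 0; c < slice_shape[1]; c++) {
        // coordinate to index: idx = crd2idx(r, c)
        auto i = r * phy_stride[0] + c * phy_stride[1];
        dst[i] = src[i];
    }
}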
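Here is the same idea as a runnable sketch, assuming both buffers are std::vector<int> of equal size. The function copy_slice_at and its parameters are hypothetical. The offset only moves the base pointer, while the row stride stays that of the physical layout.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Copy a (rows, cols) slice whose first element sits at physical index `offset`.
// `row_stride` is the row stride of the underlying physical layout.
void copy_slice_at(const std::vector<int>& src, std::vector<int>& dst,
                   std::size_t offset, std::size_t rows, std::size_t cols,
                   std::size_t row_stride) {
    assert(src.size() == dst.size());
    assert(rows > 0 && cols > 0);
    // The last element we touch must still be inside the buffer.
    assert(offset + (rows - 1) * row_stride + (cols - 1) < src.size());
    const int* s = src.data() + offset;  // shifted base pointers
    int* d = dst.data() + offset;
    for (std::size_t r = 0; r < rows; r++) {
        for (std::size_t c = 0; c < cols; c++) {
            std::size_t i = r * row_stride + c;  // column stride is 1
            d[i] = s[i];
        }
    }
}
```

Calling copy_slice_at(src, dst, 2, 2, 2, 4) on the (2, 4) example writes only the physical indices 2, 3, 6, and 7.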
Tiling is like creating a slice of global memory or shared memory and then preparing (copying) the data:
- g2g
- g2s
- s2r
- ...
Think about a slice of global memory, gAgA, and a copy of this slice, sAsA.
The first gA names the "logical layout": the global logical layout of A. The second gA names the "physical layout and backing memory": the physical layout of A, backed by global memory.
The plan is:
1. We want to copy a slice from global memory to shared memory.
2. We want to copy a slice per warp group from shared memory: the warps layout.
3. We then copy from a slice of the warp's layout to the thread's registers.
To make things easy, we introduce Tensor, a structure that takes care of the shape, the stride, crd2idx (logical coordinate to physical index), idx2crd (logical index to logical coordinate), and the base offset (e.g. physical idx + base_offset).
struct Tensor {
    void* data_ptr;
    size_t base_offset_;
    Shape shape_;
    Stride stride_;
    // number of logical elements.
    size_t size();
    // logical coordinate to **physical** index.
    size_t crd2idx(Coord crd);
    // logical index to **logical** coordinate.
    Coord idx2crd(size_t idx);
};
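To make the struct concrete, here is a minimal 2D sketch. It is not a definitive implementation: it assumes int elements, a non-owning pointer, that idx2crd enumerates coordinates in row-major order, and that crd2idx already adds the base offset. Tensor2D, Coord, and at are names invented for this sketch.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Coord = std::array<std::size_t, 2>;  // (row, col)

// Hypothetical 2D, non-owning view over a buffer of ints.
struct Tensor2D {
    int* data_ptr;
    std::size_t base_offset;
    Coord shape;
    Coord stride;

    // Number of logical elements.
    std::size_t size() const { return shape[0] * shape[1]; }

    // Logical coordinate -> physical index (base offset included, by assumption).
    std::size_t crd2idx(Coord crd) const {
        assert(crd[0] < shape[0] && crd[1] < shape[1]);
        return base_offset + crd[0] * stride[0] + crd[1] * stride[1];
    }

    // Logical index -> logical coordinate (row-major order, by assumption).
    Coord idx2crd(std::size_t idx) const {
        assert(idx < size());
        return {idx / shape[1], idx % shape[1]};
    }

    // Element access by logical coordinate.
    int& at(Coord crd) const { return data_ptr[crd2idx(crd)]; }
};
```

For the slice [2, 3, 6, 7] of the (2, 4) buffer, a view with base offset 2, shape (2, 2), and stride (4, 1) maps the logical coordinate (1, 0) to the physical index 2 + 1 * 4 + 0 = 6.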
Then the copy looks like this.
// shape = (2, 4)
// global_layout: [
// [0, 1, 2, 3]
// [4, 5, 6, 7]
// ]
//
// shape = (2, 2)
// global_local_slice: [
// [x, x, 2, 3]
// [x, x, 6, 7]
// ]
//
// shape = (2, 2)
// smem_layout: [
// [2, 3]
// [6, 7]
// ]
//
// shape = (1, 2)
// warps_layout: [
// [2, 3]
// ]
//
// shape = (1, 1)
// threads_layout: [
// [3]
// ]
// name = gmem layout of A, backed by gmem
// shape:stride = (2, 4):(4, 1)
// base offset = 0
Tensor gA;
// name = local gmem layout of A, backed by gmem
// shape:stride = (2, 2):(4, 1)
// base offset = xxx
Tensor gAgA;
// name = smem layout of A, backed by smem
// shape:stride = (2, 2):(2, 1)
// base offset = 0
Tensor sAsA;
// name = warps layout of A, backed by smem
// shape:stride = (1, 2):(2, 1)
// base offset = xxx
Tensor wAsA;
// name = threads layout of A, backed by reg
// shape:stride = (1, 1):(1, 1)
// base offset = xxx
Tensor tArA;
// global to shared memory
{
    auto src = gAgA;
    auto dst = sAsA;
    for (int i = 0; i < dst.size(); i++) {
        // get logical coordinate
        auto l_coord_src = src.idx2crd(i);
        auto l_coord_dst = dst.idx2crd(i);
        // get physical index
        auto src_idx = src.crd2idx(l_coord_src);
        auto dst_idx = dst.crd2idx(l_coord_dst);
        dst[dst_idx] = src[src_idx];
    }
}
// shared memory to warp tile
{
    auto src = sAsA;
    auto dst = wAsA;
    for (int i = 0; i < dst.size(); i++) {
        // get logical coordinate
        auto l_coord_src = src.idx2crd(i);
        auto l_coord_dst = dst.idx2crd(i);
        // get physical index
        auto src_idx = src.crd2idx(l_coord_src);
        auto dst_idx = dst.crd2idx(l_coord_dst);
        dst[dst_idx] = src[src_idx];
    }
}
// warp tile to thread's reg
{
    auto src = wAsA;
    auto dst = tArA;
    for (int i = 0; i < dst.size(); i++) {
        // get logical coordinate
        auto l_coord_src = src.idx2crd(i);
        auto l_coord_dst = dst.idx2crd(i);
        // get physical index
        auto src_idx = src.crd2idx(l_coord_src);
        auto dst_idx = dst.crd2idx(l_coord_dst);
        dst[dst_idx] = src[src_idx];
    }
}
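The global-to-shared step can be tried on the CPU with a small sketch. It assumes plain int arrays standing in for gmem and smem, a hypothetical Tensor2D view with row-major idx2crd, and, as a simplification, that src and dst have the same logical shape. copy_tensor is a made-up name.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Coord = std::array<std::size_t, 2>;  // (row, col)

// Hypothetical 2D view: base offset + shape + stride over an int buffer.
struct Tensor2D {
    int* data_ptr;
    std::size_t base_offset;
    Coord shape;
    Coord stride;

    std::size_t size() const { return shape[0] * shape[1]; }

    // Logical coordinate -> physical index.
    std::size_t crd2idx(Coord crd) const {
        return base_offset + crd[0] * stride[0] + crd[1] * stride[1];
    }

    // Logical index -> logical coordinate (row-major order, by assumption).
    Coord idx2crd(std::size_t idx) const {
        return {idx / shape[1], idx % shape[1]};
    }
};

// Copy every logical element of src into dst.
// Each side resolves its own physical index from the shared logical index.
void copy_tensor(const Tensor2D& src, const Tensor2D& dst) {
    assert(src.shape == dst.shape);  // simplification for this sketch
    for (std::size_t i = 0; i < dst.size(); i++) {
        std::size_t src_idx = src.crd2idx(src.idx2crd(i));
        std::size_t dst_idx = dst.crd2idx(dst.idx2crd(i));
        dst.data_ptr[dst_idx] = src.data_ptr[src_idx];
    }
}
```

The copy loop never mentions an offset or a stride. With gAgA as base offset 2, (2, 2):(4, 1) and sAsA as base offset 0, (2, 2):(2, 1), the elements 2, 3, 6, 7 land contiguously in the destination.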
Now I am happy to get rid of something like: ptr + base_offset + stride0 * a + stride1 * b + ...