How to multiply tensors without doing so? (3/3)

tinygrad
pytorch
tensors
Author

Andre Mirończuk

Published

September 27, 2024

First part
Second part

Let’s take tinygrad as an example.

Here’s the high-level code for matmul:

def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
    n1, n2 = len(self.shape), len(w.shape)
    assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
    if (L:=self.shape[-1]) != (R:=w.shape[-min(n2, 2)]): raise AssertionError(f"shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})")
    x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
    w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
    return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)

def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
    return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)

It might look intimidating, but it’s actually really elegant and beautiful.
You’ll see.

The first lines are simple.

We will use these tensors as examples (PyTorch has a very similar API, so we will use it for simplicity):

import torch
self = torch.rand(3, 3, 4)
w = torch.rand(2, 1, 4, 5)
self.shape, w.shape
(torch.Size([3, 3, 4]), torch.Size([2, 1, 4, 5]))

We extract the rank of both tensors.

n1, n2 = len(self.shape), len(w.shape)
n1, n2
(3, 4)

Now we do some checks to ensure the operation we want to perform is valid:

assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
if (L:=self.shape[-1]) != (R:=w.shape[-min(n2, 2)]): raise AssertionError(f"shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})")

The first line is pretty self-explanatory.

The second one will establish if you can indeed matmul those tensors.

-min(n2, 2) evaluates to either -1 or -2 (remember, n2 cannot be 0 at this point). The check therefore compares the last dimension of self with the second-to-last dimension of w, unless w is a vector, that is, a 1-dimensional tensor with a shape like (3,), as opposed to a row vector with a shape like (1, 3). In that case w has only one dimension, so that dimension is the one compared.
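To make the indexing concrete, here is a minimal sketch that works on plain shape tuples instead of tensors. The helper name contracted_dims is made up for this illustration and is not part of tinygrad or PyTorch.

```python
def contracted_dims(shape_a, shape_w):
    """Return the pair of dimensions that the shape check compares."""
    n1, n2 = len(shape_a), len(shape_w)
    assert n1 != 0 and n2 != 0, "both shapes need at least one dimension"
    left = shape_a[-1]            # always the last dimension of the left operand
    right = shape_w[-min(n2, 2)]  # index -2 for 2-D and up, index -1 for a 1-D vector
    return left, right

print(contracted_dims((3, 3, 4), (2, 1, 4, 5)))  # (4, 4): compatible
print(contracted_dims((3, 4), (4,)))             # (4, 4): w is 1-D, its only dimension is used
print(contracted_dims((3, 4), (1, 3)))           # (4, 1): a (1, 3) row vector counts as a matrix
```

The last call shows why the distinction between (3,) and (1, 3) matters: the row vector is indexed with -2, so the check compares 4 with 1 and the multiplication is rejected.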

Will that take care of all possible shape permutations between tensors?
Not quite. It will handle the most important part and turn a blind eye to all batch dimensions. We’ll touch on that later.

For two 2-by-2 tensors, it’s easy to tell if you can multiply them.

What about these shapes?
(2, 1, 6, 4) and
(4, 3, 1, 5, 4, 7)

Can you?

Absolutely.

tmp1 = torch.rand(2, 1, 6, 4)
tmp2 = torch.rand(4, 3, 1, 5, 4, 7)
(tmp1@tmp2).shape
torch.Size([4, 3, 2, 5, 6, 7])

Batch dimensions are only compared with their counterparts in the other tensor. I wrote about broadcasting rules in the first part.

What you sort of end up “matmuling” are the two last dimensions of both tensors:
(6, 4) and
(4, 7)

Out of it comes (6, 7). The 4s disappear, and you get the shape (4, 3, 2, 5, 6, 7).
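The same bookkeeping can be sketched on shape tuples alone. The function below is a hypothetical helper written for this post (it is not a tinygrad or PyTorch API) and, to stay short, it assumes both operands have at least two dimensions.

```python
from itertools import zip_longest

def matmul_shape(shape_a, shape_b):
    """Predict the result shape of a batched matmul (operands with 2+ dimensions)."""
    if len(shape_a) < 2 or len(shape_b) < 2:
        raise ValueError("this sketch only handles operands with at least 2 dimensions")
    if shape_a[-1] != shape_b[-2]:
        raise ValueError(f"inner dimensions differ: {shape_a[-1]} != {shape_b[-2]}")
    batch = []
    # Walk the batch dimensions right to left, padding the shorter shape with 1s.
    for da, db in zip_longest(reversed(shape_a[:-2]), reversed(shape_b[:-2]), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dimensions {da} and {db} cannot be broadcast")
        batch.append(max(da, db))
    # Broadcast batch dimensions first, then rows of a and columns of b.
    return (*reversed(batch), shape_a[-2], shape_b[-1])

print(matmul_shape((2, 1, 6, 4), (4, 3, 1, 5, 4, 7)))  # (4, 3, 2, 5, 6, 7)
```

The batch part, (2, 1) against (4, 3, 1, 5), broadcasts to (4, 3, 2, 5), and the matrix part contributes (6, 7).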

The second line of the snippet checks that those 4s are the same.

Great! Wait, but what about mismatching batch dimensions?

These tensors (shapes) cannot be multiplied:
(2, 6, 4) and
(3, 4, 7)

try:
    torch.rand(2, 6, 4) @ torch.rand(3, 4, 7)
except Exception as e:
    print(e)
The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

In this situation, the error is raised in the first part of the return expression ➫ (x*w). Namely, broadcasting fails.
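Here is a rough sketch of why the explicit check lets these shapes through while the elementwise multiplication does not. The broadcast helper is hypothetical and only mimics the usual right-aligned broadcasting rule; the two four-element shapes are what the reshape and transpose steps of dot() (covered further down) produce for these inputs.

```python
def broadcast(shape_x, shape_w):
    """Right-align two shapes and broadcast them, NumPy/PyTorch style."""
    size = max(len(shape_x), len(shape_w))
    px = (1,) * (size - len(shape_x)) + tuple(shape_x)
    pw = (1,) * (size - len(shape_w)) + tuple(shape_w)
    out = []
    for dx, dw in zip(px, pw):
        if dx != dw and 1 not in (dx, dw):
            raise ValueError(f"cannot broadcast {shape_x} with {shape_w}: {dx} != {dw}")
        out.append(max(dx, dw))
    return tuple(out)

a_shape, w_shape = (2, 6, 4), (3, 4, 7)
# The explicit check passes: last dim of a equals second-to-last dim of w.
assert a_shape[-1] == w_shape[-2]

# Shapes after dot() has reshaped both operands and transposed w.
x_shape, wt_shape = (2, 6, 1, 4), (3, 1, 7, 4)
try:
    broadcast(x_shape, wt_shape)
except ValueError as e:
    print(e)  # the leading batch dimensions 2 and 3 clash
```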

Next, we will reshape our tensors just like before.

x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])

Both tensors will be reshaped so that multiplying them with * broadcasts them into a higher-dimensional tensor (a cube, for two 2-dimensional inputs).

This operation will put the correct number of singleton dimensions in second to last place.

Easy stuff first:

print(*self.shape[0:-1])
print(self.shape[-1])
3 3
4

Moving on.

Now, this is beautiful Python:

*[1]*min(n1-1, n2-1, 1)

min(n1 - 1, n2 - 1, 1) will evaluate to either 1 or 0. It just checks whether either of the tensors is 1-dimensional (e.g., shape = (3,)). We have already seen this kind of trick ➫ -min(n2, 2).

*[1]*

What is this?

If min(n1 - 1, n2 - 1, 1) evaluates to 1, [1] will not change. Then, using *, it will be unpacked to just a 1. So exactly the output of min(n1 - 1, n2 - 1, 1).

print(min(n1-1, n2-1, 1))
print(*[1] * min(3 - 1, 4 - 1, 1))
1
1

So:

self.reshape(3, 3, 1, 4)

However, if min(n1 - 1, n2 - 1, 1) evaluates to 0, then we don’t want to put anything there (x = self.reshape(3, 3, 4) and not x = self.reshape(3, 3, 0, 4) or a reshape with any other placeholder in that slot).

To achieve that, we multiply the list [1] by the result. Multiplying by 1 changes nothing, as we stated already. But multiplying by 0 gives us not [0] but an empty list []. And when you unpack an empty list, it disappears.

print("♘", *[1] * min(3 - 1, 1 - 1, 1), "♘")
print("♘", "♘")
♘ ♘
♘ ♘

So:

self.reshape(3, 3, 4)

Instead of:

self.reshape(3, 3, 0, 4)

Genius.
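Both cases can be checked on plain tuples. The helper below is a hypothetical stand-in that only rebuilds the arguments passed to reshape for the left operand; it does not touch any tensor.

```python
def lhs_reshape_args(shape, n1, n2):
    """Rebuild the arguments that dot() passes to reshape for the left operand."""
    extra = [1] * min(n1 - 1, n2 - 1, 1)  # [1] if both operands are at least 2-D, else []
    return (*shape[0:-1], *extra, shape[-1])

print(lhs_reshape_args((3, 3, 4), 3, 4))  # (3, 3, 1, 4): singleton inserted
print(lhs_reshape_args((3, 3, 4), 3, 1))  # (3, 3, 4): w is 1-D, nothing inserted
print(lhs_reshape_args((4,), 1, 2))       # (4,): self is 1-D, nothing inserted
```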

Let’s peek at the shape.

x.shape
torch.Size([3, 3, 1, 4])

Just as we discussed.

We’ll do something similar for the second tensor:

w = w.reshape(
    *w.shape[0:-2],
    *[1] * min(n1 - 1, n2 - 1, 1),
    *w.shape[-min(n2, 2):]
).transpose(-1, -min(n2, 2))
print(*w.shape[0:-2], end="   ")
print(*[1]*min(n1 - 1, n2 - 1, 1), end="   ")
print(*w.shape[-min(n2, 2):])
2 1   1   4 5

*w.shape[0:-2] extracts all dimensions not including the last two (could be []).
*[1]*min(n1-1, n2-1, 1) does the same thing as before.
*w.shape[-min(n2, 2):] extracts all the left-out dimensions (either the last two or the last one).

So:

w.reshape(2, 1, 1, 4, 5)
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):])

Let’s peek at the shape (before transposing).

w.shape
torch.Size([2, 1, 1, 4, 5])

We’ll transpose it for the same reasons we stated in previous parts.
Again, -min(n2, 2) checks whether w is a 1-dimensional tensor.

.transpose(-1, -2) switches the two last dimensions.
.transpose(-1, -1) does nothing.

Examples:

torch.rand(2, 3, 4).transpose(-1, -2).shape
torch.Size([2, 4, 3])
torch.rand(2, 3, 4).transpose(-1, -1).shape
torch.Size([2, 3, 4])
try:
    torch.rand(3).transpose(-1, -2).shape
except Exception as e:
    print(e)
Dimension out of range (expected to be in range of [-1, 0], but got -2)

Let’s finish transposing.

w = w.transpose(-1, -min(n2, 2))
w.shape
torch.Size([2, 1, 1, 5, 4])
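The reshape and the transpose of w can also be traced on shapes alone. The function below is a hypothetical sketch for this post, assuming only that transpose(-1, -k) swaps the last axis with axis -k.

```python
def rhs_shape(shape, n1):
    """Shape of w after the reshape and transpose steps, given the rank n1 of self."""
    n2 = len(shape)
    k = min(n2, 2)                        # 2 for matrices and up, 1 for a vector
    extra = [1] * min(n1 - 1, n2 - 1, 1)  # optional singleton dimension
    reshaped = [*shape[0:-2], *extra, *shape[-k:]]
    # transpose(-1, -k): swap the last axis with axis -k (a no-op when k == 1).
    reshaped[-1], reshaped[-k] = reshaped[-k], reshaped[-1]
    return tuple(reshaped)

print(rhs_shape((2, 1, 4, 5), n1=3))  # (2, 1, 1, 5, 4)
print(rhs_shape((4,), n1=3))          # (4,): 1-D w is left alone
print(rhs_shape((4, 5), n1=1))        # (5, 4): no singleton, but still transposed
```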

Lastly, we’ll multiply them and sum over the last dimension, skipping the custom logic regarding the library’s dtypes.

return (
    (x * w)
    .sum(-1, acc_dtype=acc_dtype)
    .cast(
        least_upper_dtype(x.dtype, w.dtype)
        if acc_dtype is None
        else acc_dtype
    )
)
x.shape, w.shape
(torch.Size([3, 3, 1, 4]), torch.Size([2, 1, 1, 5, 4]))

Now, the multiplication broadcasts both tensors into a higher dimension over the created singleton dimensions (for two 2-dimensional starting tensors, this would result in a cube). Then sum(-1) reduces the result back down by summing over the last dimension, which it removes by default (keepdim=False).
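For the 2-dimensional case, the cube can be spelled out with nested lists. This is only an illustrative sketch in plain Python, with the hypothetical name matmul_via_cube; it mirrors the idea and not the way tinygrad actually executes it.

```python
def matmul_via_cube(a, b):
    """Multiply two 2-D lists by building the broadcast product, then summing the last axis."""
    m, k, n = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    # x has shape (m, 1, k); w, after the transpose, has shape (1, n, k).
    # Their broadcast product is an (m, n, k) cube: cube[i][j][t] = a[i][t] * b[t][j].
    cube = [[[a[i][t] * b[t][j] for t in range(k)] for j in range(n)] for i in range(m)]
    # Summing over the last axis collapses the cube to an (m, n) matrix.
    return [[sum(cell) for cell in row] for row in cube]

a = [[1, 2, 3], [4, 5, 6]]    # shape (2, 3)
b = [[1, 0], [0, 1], [2, 2]]  # shape (3, 2)
print(matmul_via_cube(a, b))  # [[7, 8], [16, 17]]
```

Each cell of the result is the sum of one row of the cube, which is exactly the dot product of a row of a with a column of b.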

(x*w).shape
torch.Size([2, 3, 3, 5, 4])
res = (x*w).sum(-1)

And…

res
tensor([[[[0.5259, 0.3000, 0.6004, 0.2894, 0.4077],
          [0.8279, 0.7619, 0.8626, 0.4019, 0.8703],
          [1.1576, 0.7582, 1.1252, 0.3583, 0.8744]],

         [[1.8753, 1.0986, 1.9939, 0.7821, 1.4981],
          [1.6079, 0.9101, 1.8997, 0.9157, 1.5089],
          [1.4465, 0.8693, 1.6667, 0.8347, 1.1599]],

         [[1.4186, 1.1526, 1.4041, 0.5483, 1.2553],
          [1.9908, 1.2367, 2.1596, 0.8922, 1.7607],
          [1.0949, 0.5109, 1.1055, 0.3661, 0.5976]]],


        [[[0.3232, 0.3395, 0.2736, 0.4738, 0.5532],
          [0.7879, 0.4619, 0.4811, 0.7820, 0.8044],
          [0.7571, 0.8372, 0.8546, 1.0064, 0.7694]],

         [[1.0566, 1.3344, 1.3176, 1.7892, 1.6994],
          [0.8136, 1.0792, 1.0664, 1.7680, 2.0364],
          [0.9499, 0.8993, 0.7038, 1.2998, 1.5714]],

         [[1.2052, 0.8936, 0.8995, 1.2275, 1.0797],
          [1.1456, 1.3835, 1.4295, 2.0276, 2.0094],
          [0.5575, 0.8357, 0.7197, 0.8475, 0.6852]]]])
res.shape
torch.Size([2, 3, 3, 5])

Here it is. Nice and warm.

It went through a very similar route as our recruits from before and came out just as it should.
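To tie the steps together, here is a hedged end-to-end port of the same logic. It uses NumPy rather than tinygrad or PyTorch, so swapaxes stands in for transpose and the acc_dtype and cast handling is left out; the name dot_sketch is made up for this post.

```python
import numpy as np

def dot_sketch(a, w):
    """Matmul as reshape + broadcasted multiply + sum, following the steps above."""
    n1, n2 = a.ndim, w.ndim
    assert n1 != 0 and n2 != 0, "both arguments need to be at least 1D"
    if a.shape[-1] != w.shape[-min(n2, 2)]:
        raise AssertionError(f"shapes {a.shape} and {w.shape} cannot be multiplied")
    extra = [1] * min(n1 - 1, n2 - 1, 1)                 # optional singleton dimension
    x = a.reshape(*a.shape[0:-1], *extra, a.shape[-1])
    w = w.reshape(*w.shape[0:-2], *extra, *w.shape[-min(n2, 2):])
    w = np.swapaxes(w, -1, -min(n2, 2))                  # NumPy's two-axis transpose
    return (x * w).sum(-1)                               # broadcast, then reduce

rng = np.random.default_rng(0)
a, w = rng.random((3, 3, 4)), rng.random((2, 1, 4, 5))
print(dot_sketch(a, w).shape)                # (2, 3, 3, 5)
print(np.allclose(dot_sketch(a, w), a @ w))  # True
```

Comparing against the @ operator is a quick sanity check that the reshape-multiply-sum route produces the same numbers as an ordinary matmul, including for 1-dimensional operands.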

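tinygrad compiles these operations into kernels for whichever backend is active (its CPU backend, for example, generates C code). So underneath, there are plain buffers.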

What about PyTorch, you say? We’ll drown in ATen maybe next time.
Bring a mace.