multiply

https://blog.csdn.net/da_kao_la/article/details/87484403

#matmul
Tensors support four kinds of multiplication.

1. *

For a = b * c, if b and c differ in size, they are first aligned to a common shape by broadcasting (conceptually an expand), and then multiplied element-wise.

import torch
b = torch.ones(3,4)
c = 2
print(b * c)

# a row vector of length 4 is expanded across the 3 rows
c = torch.tensor([1,2,3,4])
print(b * c)

# a (3, 1) column vector is expanded across the 4 columns
c = torch.tensor([1,2,3]).reshape((3,1))
print(b * c)

# when both operands are matrices, their shapes must match (or at least be broadcast-compatible)
b = torch.tensor([[1,2],[2,3]])
print(b * b)
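
As a quick check of the cases above, here is a minimal sketch that records the result shapes instead of printing the values. The names `row`, `col`, `outer` and `mismatch_raises` are made up for illustration. It also shows that two 2-D tensors only need broadcast-compatible shapes, and that incompatible shapes raise an error rather than being silently aligned.

```python
import torch

b = torch.ones(3, 4)

# A scalar, a length-4 row vector and a (3, 1) column all broadcast to (3, 4).
row = torch.tensor([1, 2, 3, 4])
col = torch.tensor([1, 2, 3]).reshape(3, 1)
shape_scalar = tuple((b * 2).shape)
shape_row = tuple((b * row).shape)
shape_col = tuple((b * col).shape)

# Two 2-D tensors need not have identical shapes, only compatible ones:
# (3, 1) * (1, 4) expands both sides to (3, 4).
outer = col * row.reshape(1, 4)

# Incompatible shapes, here (3, 4) and (4, 3), raise a RuntimeError.
try:
    b * torch.ones(4, 3)
    mismatch_raises = False
except RuntimeError:
    mismatch_raises = True

print(shape_scalar, shape_row, shape_col, tuple(outer.shape), mismatch_raises)
```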

2. torch.mul

torch.mul is used the same way as *: it is also element-wise and supports broadcasting (see "broadcasting semantics" in the PyTorch docs).

a = torch.ones(3,4)
b = torch.tensor([1,2,3]).reshape((3,1))

c = torch.mul(a,b)
print(c)
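
A small sketch to confirm that the operator, the function and the method form give the same result. The variable names `via_operator`, `via_function`, `via_method` and `same` are only for illustration.

```python
import torch

a = torch.ones(3, 4)                          # float32
b = torch.tensor([1, 2, 3]).reshape(3, 1)     # int64, broadcast across the columns

via_operator = a * b
via_function = torch.mul(a, b)
via_method = a.mul(b)

# All three spellings compute the same element-wise product.
same = torch.equal(via_operator, via_function) and torch.equal(via_operator, via_method)

# Mixing float32 and int64 operands promotes the result to float32.
print(same, tuple(via_function.shape), via_function.dtype)
```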

3. torch.mm

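This is finally matrix multiplication in the mathematical sense. Both tensors must be 2-D, with shapes (n, m) and (m, p), and there is no broadcasting.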


a = torch.ones(3,4)
b = torch.ones(4,3)
c = torch.mm(a, b)
print(c)
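
The shape rule can be checked with a short sketch like the following, assuming the same all-ones inputs as above. The flags `rejects_3d` and `rejects_inner` are hypothetical names used only to record whether torch.mm raised.

```python
import torch

a = torch.ones(3, 4)
b = torch.ones(4, 3)

# (3, 4) x (4, 3) -> (3, 3); every entry is a sum of four 1 * 1 products, i.e. 4.
c = torch.mm(a, b)

# torch.mm does not broadcast: a 3-D operand is rejected.
try:
    torch.mm(a, torch.ones(5, 4, 3))
    rejects_3d = False
except RuntimeError:
    rejects_3d = True

# Mismatched inner dimensions are rejected too: (3, 4) x (3, 4) has 4 != 3.
try:
    torch.mm(a, a)
    rejects_inner = False
except RuntimeError:
    rejects_inner = True

print(tuple(c.shape), rejects_3d, rejects_inner)
```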

4. torch.matmul

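torch.matmul is the version of torch.mm that supports broadcasting: the last two dimensions are matrix-multiplied, and any leading (batch) dimensions are broadcast.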

a = torch.ones(3,4)
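# broadcasting aligns shapes from the trailing dimensions; here b's last two dims (4, 2)
# can be matrix-multiplied with a (3, 4), and a is reused for each of the 5 batch entries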
b = torch.ones(5,4,2)

c = torch.matmul(a,b)
print(c)
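
To make the batch behaviour concrete, here is a minimal sketch that compares torch.matmul with calling torch.mm once per batch entry. The names `d`, `stacked`, `v`, `dot` and `mv` are illustrative. It also touches on the `@` operator and on 1-D inputs, which torch.matmul accepts and torch.mm does not.

```python
import torch

a = torch.ones(3, 4)
b = torch.ones(5, 4, 2)

c = torch.matmul(a, b)   # a is broadcast over the batch dimension -> (5, 3, 2)
d = a @ b                # the @ operator performs the same matmul

# Same result as running torch.mm on each of the 5 slices and stacking them.
stacked = torch.stack([torch.mm(a, b[i]) for i in range(b.shape[0])])

# 1-D inputs: vector @ vector is a dot product, matrix @ vector is a
# matrix-vector product.
v = torch.ones(4)
dot = torch.matmul(v, v)   # 0-dim tensor holding 4.0
mv = torch.matmul(a, v)    # shape (3,)

print(tuple(c.shape), dot.item(), tuple(mv.shape))
```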