linear

#linear

https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
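Linear actually transforms the length of the input's last dimension. The transformation is carried out by a matrix multiplication between the input tensor and the layer's internal weight, followed by adding the bias.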
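A minimal sketch of this length conversion, assuming a layer with in_features=3 and out_features=4 (the sizes are arbitrary, chosen only for illustration). Both a single vector and a batch of vectors have their last dimension changed from 3 to 4:

```python
import torch
from torch import nn

layer = nn.Linear(3, 4)  # last-dimension length 3 -> 4

vec = torch.randn(3)       # a single vector of length 3
batch = torch.randn(5, 3)  # 5 vectors of length 3

print(layer(vec).shape)    # torch.Size([4])
print(layer(batch).shape)  # torch.Size([5, 4])
```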

Applies an affine linear transformation to the incoming data:

y = xA^T + b
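To make the formula concrete, here is a small worked example. It assumes a hypothetical layer with in_features=2 and out_features=3 whose random parameters are overwritten with hand-picked values, so each output element can be checked by hand as the dot product of x with one row of A, plus the matching entry of b:

```python
import torch
from torch import nn

layer = nn.Linear(2, 3)  # hypothetical sizes: in_features=2, out_features=3

# Replace the random parameters with hand-picked values.
# A (weight) has shape (out_features, in_features) = (3, 2); b (bias) has shape (3,)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[1.0, 0.0],
                                     [0.0, 1.0],
                                     [1.0, 1.0]]))
    layer.bias.copy_(torch.tensor([0.0, 0.0, 1.0]))

x = torch.tensor([[2.0, 3.0]])  # one sample with 2 features
y = layer(x)

# x A^T = [2*1 + 3*0, 2*0 + 3*1, 2*1 + 3*1] = [2, 3, 5]; adding b gives [2, 3, 6]
print(y.detach())  # tensor([[2., 3., 6.]])
```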

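Here A and b are the Linear layer's weight and bias. Their shapes are fixed by the in_features and out_features given when the Linear is defined: weight has shape (out_features, in_features) and bias has shape (out_features,). Both are randomly initialized by default (a different initialization scheme can also be applied), and are then updated as the model is trained.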
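A sketch of the parameter lifecycle, assuming in_features=3 and out_features=4. The default range U(-1/sqrt(in_features), 1/sqrt(in_features)) is the one stated in the PyTorch nn.Linear docs; the choice of Xavier init for weight, zeros for bias, the toy squared-output loss and the SGD learning rate are illustrative assumptions, not something the layer requires:

```python
import math
import torch
from torch import nn

in_features, out_features = 3, 4
m = nn.Linear(in_features, out_features)

# Shapes are fixed by the constructor arguments
print(m.weight.shape)  # torch.Size([4, 3]) -> (out_features, in_features)
print(m.bias.shape)    # torch.Size([4])    -> (out_features,)

# Default init samples from U(-1/sqrt(in_features), 1/sqrt(in_features));
# a tiny tolerance covers float32 rounding of the bound
bound = 1 / math.sqrt(in_features)
within_bound = bool(m.weight.abs().max().item() <= bound + 1e-6)

# A different initialization can be applied afterwards (illustrative choice)
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)

# One SGD step on a toy loss: weight and bias are trainable, so they change
opt = torch.optim.SGD(m.parameters(), lr=0.1)
before = m.weight.detach().clone()
loss = m(torch.randn(8, in_features)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()
changed = not torch.equal(before, m.weight.detach())
print(within_bound, changed)
```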

Note

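The in_features and out_features of Linear refer only to the length of the last dimension (in_features must equal the input's last-dimension size, and out_features becomes the output's last-dimension size); all other dimensions are left unchanged.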
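A sketch with a 3-D input, assuming a (batch, seq_len, in_features) layout chosen only for illustration. Only the last dimension changes, and the result matches flattening the leading dimensions, applying the layer to each row, and restoring the shape. A last dimension that does not equal in_features raises an error:

```python
import torch
from torch import nn

m = nn.Linear(3, 4)

x = torch.randn(2, 5, 3)  # e.g. (batch, seq_len, in_features)
y = m(x)
print(y.shape)  # torch.Size([2, 5, 4]): only the last dim changed

# Equivalent view: flatten leading dims, apply row by row, restore the shape
y_flat = m(x.reshape(-1, 3)).reshape(2, 5, 4)
print(torch.allclose(y, y_flat, atol=1e-6))  # True

# Last dim 6 != in_features 3, so the matrix multiply fails
raised = False
try:
    m(torch.randn(2, 5, 6))
except RuntimeError as e:
    raised = True
    print("shape mismatch:", e)
```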

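import torch
from torch import nn

m = nn.Linear(3, 4)  # in_features=3, out_features=4
print(m.weight)      # shape (4, 3): (out_features, in_features)
print(m.bias)        # shape (4,): (out_features,)

x = torch.randn(2, 3)  # 2 samples, each with 3 features
output1 = m(x)         # shape (2, 4)

# The same computation written out by hand: y = x A^T + b
output2 = torch.matmul(x, m.weight.t()) + m.bias

print("output1:\n", output1)
print("output2:\n", output2)

# The two paths may use different kernels and differ in the last bits,
# so compare with a tolerance instead of exact equality
assert torch.allclose(output1, output2)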
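The module also has a functional counterpart, torch.nn.functional.linear, which takes the weight and bias as explicit arguments (nn.Linear's forward calls it with its stored parameters). A minimal sketch, assuming the same 3 -> 4 layer; passing no bias gives just x A^T:

```python
import torch
from torch import nn
import torch.nn.functional as F

m = nn.Linear(3, 4)
x = torch.randn(2, 3)

# Functional form: parameters are passed in instead of stored in a module
out_module = m(x)
out_functional = F.linear(x, m.weight, m.bias)
print(torch.allclose(out_module, out_functional))  # True

# Without a bias, F.linear computes only x A^T
out_no_bias = F.linear(x, m.weight)
print(torch.allclose(out_no_bias, x @ m.weight.t()))  # True
```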