Pytorch를 이용해 짠 대부분의 소스코드들에는 nn.Linear() 함수가 거의 꼭 들어가 있는 것을 알 수 있을 것이다.
그만큼 자주 많이 사용되고 크게 어려운 개념은 아니지만, 그렇다고 대충 넘어가면 추후 코드를 이해함에 있어서 문제가 생길 수 있다.
따라서 nn.Linear()에 대해 차근차근 파헤쳐보자.
nn.Linear()란?
Pytorch에서 선형회귀 모델은 nn.Linear() 함수에 구현되어 있다.
Pytorch 공식 문서에서 설명하는 nn.Linear() 함수는 다음과 같다.
입력 인자로 in_features와 out_features를 받고 이에 맞춰 반환한다.
예를 들어 내가 768차원 짜리를 64차원으로 만들고 싶으면
nn.Linear(768,64)
위와 같은 형태로 작성하면 된다.
단순히 Linear Transformation(선형 변환)을 진행해준다고 보면 된다.
예를 들어 아래와 같이 코드를 작성한다고 해보자.
3x2 행렬을 입력으로 사용하고, input_features와 output_features는 각각 2로 설정한다.
정리하면 3x2행렬을 입력으로 넣어서 3x2행렬을 얻는 것이다.
import torch
import torch.nn as nn
x = torch.tensor([[1.0, -1.0],
[0.0, 1.0],
[0.0, 0.0]])
in_features = x.shape[1] # = 2
out_features = 2
m = nn.Linear(in_features, out_features)
m이라는 변수에 nn.Linear()의 값이 할당되는 것을 알 수 있을 것이다.
그렇다면 이 m에는 어떤 것이 담겨 있는 것일까?
여기서 다시 앞서 봤던 Pytorch 공식 문서를 한번 더 들여다보자.
마지막의 Variables를 보면 weight와 bias가 존재하는 것을 알 수 있을 것이다.
따라서 m의 weight와 bias를 찍어보면 다음과 같이 나온다.
>>> m.weight
tensor([[-0.4500, 0.5856],
[-0.1807, -0.4963]])
>>> m.bias
tensor([ 0.2223, -0.6114])
이를 통해 우리는 nn.Linear()를 통해 생성된 값이 2x2의 weight matirx와 2x1의 bias matrix를 가지고 있다라는 것을 알 수 있다.
따라서 x를 m에 입력으로 사용하면 다음과 같다.
>>> y = m(x)
tensor([[-0.8133, -0.2959],
[ 0.8079, -1.1077],
[ 0.2223, -0.6114]])
이 내부의 연산과정을 코드로 표현하면 다음과 같다.
y = x.matmul(m.weight.t()) + m.bias # y = x*W^T + b
결국 nn.Linear() 함수를 정리하자면
위 식의 W와 b를 담고 있는 함수를 반환한다 라고 생각하면 된다.
따라서 이 nn.Linear()를 하나의 객체에 담고 이 객체에 input을 넣어주면 입력한 out_features에 맞는 차원으로 선형 변환된 값이 반환되는 것이다.
Reference
https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#linear
https://stackoverflow.com/questions/54916135/what-is-the-class-definition-of-nn-linear-in-pytorch