Softmax 역전파 공식 메모리 최적화

 

안녕하세요? 삼각형입니다.

GPU에서 Softmax 역전파를 구현할 때, 메모리 부족으로 인해 어려움을 겪을 수 있습니다. Softmax를 Pi=softmax(Si)로 정의했을 때, 역전파 공식은 다음과 같습니다.

dSi=(diag(Pi)PiPiT)dPi

이 공식의 주요 문제점은 대각 행렬의 사용에 있습니다. 대각 행렬은 상당한 양의 메모리를 필요하기 때문에, GPU에서 이를 구현하는 것은 실질적으로 불가능합니다. 그러나 계산 순서를 조정함으로써 메모리 부족 문제를 해결할 수 있습니다. 메모리 사용량이 적은 Softmax 역전파 공식은 다음과 같습니다.

dSi=PidPi(PiTdPi)Pi

이 공식을 이해하기 위해 수학적 증명 대신 실제 예제를 살펴보겠습니다.

Pi=[0.10.50.4]dPi=[abc]

먼저 Softmax 역전파 공식에 PidPi를 대입하면 다음과 같습니다.

dSi=([0.10000.50000.4][0.10.50.4][0.10.50.4])dPi=([0.10000.50000.4][0.010.050.040.050.250.20.040.20.16])[abc]=[0.09a0.05b0.04c0.05a+0.25b0.2c0.04a0.2b0.24c]

이 계산 과정에서는 두 개의 n2 크기의 행렬이 사용되며, 이로 인해 상당한 양의 메모리가 소요됩니다. 메모리를 절약하기 위해서는 n2 크기의 행렬 사용을 피해야 합니다. 이제 메모리 사용량이 적은 Softmax 역전파 공식에 PidPi를 대입하여 동일한 결과가 나오는지 확인해 보도록 하겠습니다.

dSi=[0.10.50.4][abc]([0.10.50.4][abc])[0.10.50.4]=[0.1a0.5b0.4c](0.1a+0.5b+0.4c)[0.10.50.4]=[0.09a0.05b0.04c0.05a+0.25b0.2c0.04a0.2b0.24c]

두 공식의 결과가 동일한 것을 확인할 수 있습니다. 또한 메모리 사용량이 적은 Softmax 역전파 공식을 잘 정리하면 계산을 더 간소화할 수 있습니다.

dSi=[0.1a0.5b0.4c](0.1a+0.5b+0.4c)[0.10.50.4]=[0.10.50.4]([abc][0.10.50.4][abc])

이 결과를 기호를 사용하면, 다음과 같은 형태로 나타낼 수 있습니다.

dSi=PidPiDiPi=Pi(dPiDi)Di=PiTdPi

이 방법을 통해 메모리 사용량을 줄이면서도 Softmax 역전파를 계산을 수행할 수 있습니다.

참고

감사합니다.