Softmax 역전파 공식 메모리 최적화

 

안녕하세요? 삼각형입니다.

GPU에서 Softmax 역전파를 구현할 때, 메모리 부족으로 인해 어려움을 겪을 수 있습니다. Softmax를 \(P_i = softmax(S_i)\)로 정의했을 때, 역전파 공식은 다음과 같습니다.

\[dS_i = (diag(P_i) - P_i P_i^T)dP_i\]

이 공식의 주요 문제점은 대각 행렬의 사용에 있습니다. 대각 행렬은 상당한 양의 메모리를 필요하기 때문에, GPU에서 이를 구현하는 것은 실질적으로 불가능합니다. 그러나 계산 순서를 조정함으로써 메모리 부족 문제를 해결할 수 있습니다. 메모리 사용량이 적은 Softmax 역전파 공식은 다음과 같습니다.

\[dS_i = P_i \circ dP_i - (P_i^T dP_i)P_i\]

이 공식을 이해하기 위해 수학적 증명 대신 실제 예제를 살펴보겠습니다.

\[P_i=\begin{bmatrix}0.1\\0.5\\0.4\end{bmatrix} dP_i=\begin{bmatrix}a\\b\\c \end{bmatrix}\]

먼저 Softmax 역전파 공식에 \(P_i\)와 \(dP_i\)를 대입하면 다음과 같습니다.

\[\begin{align*} dS_i &= (\begin{bmatrix}0.1&0&0\\0&0.5&0\\0&0&0.4\end{bmatrix} - \begin{bmatrix}0.1\\0.5\\0.4\end{bmatrix}\begin{bmatrix}0.1&0.5&0.4\end{bmatrix} )dP_i \newline &= (\begin{bmatrix}0.1&0&0\\0&0.5& 0\\0&0&0.4\end{bmatrix} - \begin{bmatrix} 0.01&0.05&0.04\\0.05&0.25&0.2\\0.04&0.2&0.16\end{bmatrix})\begin{bmatrix}a\\b\\c \end{bmatrix} \newline &= \begin{bmatrix}0.09a - 0.05b - 0.04c\\-0.05a + 0.25b - 0.2c\\ -0.04a - 0.2b - 0.24c\end{bmatrix} \end{align*}\]

이 계산 과정에서는 두 개의 \(n^2\) 크기의 행렬이 사용되며, 이로 인해 상당한 양의 메모리가 소요됩니다. 메모리를 절약하기 위해서는 \(n^2\) 크기의 행렬 사용을 피해야 합니다. 이제 메모리 사용량이 적은 Softmax 역전파 공식에 \(P_i\)와 \(dP_i\)를 대입하여 동일한 결과가 나오는지 확인해 보도록 하겠습니다.

\[\begin{align*} dS_i &= \begin{bmatrix}0.1\\0.5\\0.4\end{bmatrix} \circ \begin{bmatrix}a\\b\\c \end{bmatrix} - (\begin{bmatrix}0.1&0.5&0.4\end{bmatrix}\begin{bmatrix}a\\b\\c \end{bmatrix})\begin{bmatrix}0.1\\0.5\\0.4\end{bmatrix} \newline &= \begin{bmatrix}0.1a\\0.5b\\0.4c\end{bmatrix} - (0.1a+0.5b+0.4c) \begin{bmatrix}0.1\\0.5\\0.4\end{bmatrix} \newline &= \begin{bmatrix}0.09a - 0.05b - 0.04c\\-0.05a + 0.25b - 0.2c\\ -0.04a - 0.2b - 0.24c\end{bmatrix} \end{align*}\]

두 공식의 결과가 동일한 것을 확인할 수 있습니다. 또한 메모리 사용량이 적은 Softmax 역전파 공식을 잘 정리하면 계산을 더 간소화할 수 있습니다.

\[\begin{align*} dS_i &= \begin{bmatrix}0.1a\\0.5b\\0.4c\end{bmatrix} - (0.1a+0.5b+0.4c) \begin{bmatrix}0.1\\0.5\\0.4\end{bmatrix} \newline &= \begin{bmatrix}0.1\\0.5\\0.4\end{bmatrix} \circ (\begin{bmatrix}a\\b\\c \end{bmatrix} - \begin{bmatrix}0.1&0.5&0.4\end{bmatrix}\begin{bmatrix}a\\b\\c \end{bmatrix}) \end{align*}\]

이 결과를 기호를 사용하면, 다음과 같은 형태로 나타낼 수 있습니다.

\[\begin{align*} dS_i &= P_i \circ dP_i - D_iP_i \newline &= P_i(dP_i - D_i) \newline D_i &= P_i^T dP_i \end{align*}\]

이 방법을 통해 메모리 사용량을 줄이면서도 Softmax 역전파를 계산을 수행할 수 있습니다.

참고

감사합니다.