2017년 1월 24일 화요일

분류 오차에 Cross entropy를 사용하는 이유

신경망 분류 오차를 줄이기 위한 최적화는 Net 출력 결과와 사용자 Label 정보 차이를 Error 정의한 후, 값을 줄이도록 Net 파라메터를 바꾸어 나가는 것이다

오차로는 분류 오차(classification error) 평균제곱 오차(MSE: mean square error) 일반적으로 생각할 있는 것들이다.
하지만 이들 보다 평균 Cross Entropy 오차(ACE: Averaged cross entropy error)를 빈번하게 사용하며 이유가 있다


적절히 학습된 Net 2개 있다고 하자. 부류(class) A, B, C 3개라고 했   두 넷이 주는 결과는 아래와 같다고 가정한다.


첫번째 넷이 주는 계산 결과:
계산결과         | 라벨(A/B/C)           | correct?
-----------------------------------------------
0.3  0.3  0.4  | 0  0  1 (A)          | yes
0.3  0.4  0.3  | 0  1  0 (B)          | yes
0.1  0.2  0.7  | 1  0  0 (C)          | no
분류 오차(classification error)를 계산해 보면 사용된 3개 샘플 중에 1개가 라벨과 일치하지 않으므로 1/3=0.33이다. 또한 분류 정확도(classification accuracy)는 2/3=0.67이다. 계산 결과를 보면 첫 샘플 2개는 겨우 맞추었고 세번째 샘플은 완전히 틀렸다.




두번째 넷이 주는 계산 결과:

계산결과         | 라벨(A/B/C)           | correct?
-----------------------------------------------
0.1  0.2  0.7  | 0  0  1 (A)          | yes
0.1  0.7  0.2  | 0  1  0 (B)          | yes
0.3  0.4  0.3  | 1  0  0 (C)          | no
첫번째 Net과 마찬가지로 분류 오차는 0.33이고, 분류 정확도는 0.67이다. 그러나 첫 두 샘플은 위 Net 보다 좀 더 확실히 맞추었고 세번째 샘플은 아깝게 틀렸다. 


위 두 Net을 비교하면서 분류 오차를 살펴 보면, 단순 분류 오차 계산은 틀린 개수에 대한 결과만 줄 뿐 라벨과 비교하여 얼마나 많이 틀렸는지, 얼마나 정확하게 맞았는지 그 정도에 대한 값을 제공하지 않는다. 



이와 비교하여 Cross entropy 오차를 계산해 보자. Cross entropy error의 정의

$-\sum_{i} y_i log(y_i^\prime)$

와 같다. $y_i$는 라벨값으로 one-hot vector로 주어지고, $y_i^\prime$는 넷 계산결과이다. 
첫번째 넷, 첫번째 샘플에 대해 계산해 보면 다음과 같다. 


-( (ln(0.3)*0) + (ln(0.3)*0) + (ln(0.4)*1) ) = -ln(0.4)


나머지 두 샘플 모두에 대해 계산하고 평균하면       


-(ln(0.4) + ln(0.4) + ln(0.1)) / 3 = 1.38

이다. 두번쨰 넷에 대해 평균 cross entropy를 계산하면


-(ln(0.7) + ln(0.7) + ln(0.3)) / 3 = 0.64

가 된다. 두 넷의 결과를 비교해 보면 두번째 넷이 오차가 더 작음을 알 수 있다. 즉, 넷이 주는 분류 오차에 정확도가 고려되어 최적화 관점에서 어떤 넷이 더 잘 학습되었는지를 알 수 있다. 수식에서 $log$ 연산자가 그 역할을 한다. 



다음으로 평균 제곱오차에 대해 살펴 보자. 
첫번째 넷, 첫번째 샘플에 대해 제곱오차를 살펴보면


(0.3 - 0)^2 + (0.3 - 0)^2 + (0.4 - 1)^2 = 0.09 + 0.09 + 0.36 = 0.54


이고 나머지 두개의 샘플에 대해 계산하고 평균한 제곱오차를 계산하면 다음과 같다. 
(0.54 + 0.54 + 1.34) / 3 = 0.81

첫 두 샘플은 맞은 것이고 세번째는 틀린 것이다. 제곱오차 크기는 세번째가 가장 크다.


두번쨰 넷에 대해서도 유사하게 계산하면
(0.14 + 0.14 + 0.74) / 3 = 0.34
이다. 


두 넷에 대한 계산 결과에서 보듯이 MSE는 틀린 샘플에 대해 더 집중하는 특성을 가진다. 맞은 것과 틀린 것에 똑같이 집중해야 하는데 그렇지 않아 오차 정의로는 적절하지 않다.



학습 과정 동안 나타나는 평균 제곱 오차(MSE)와 교차 엔트로피 오차(ACE)를 비교해 보자.
역 전파 학습 중에 목표 값(label)에 따라 출력 노드 값을 1.0 또는 0.0으로 설정하려고 한다.

이 때, MSE를 사용하면 가중치 계산에서 기울기 값에 (output) * (1 - output)이라는 조정 요소가 포함된다. 계산 된 출력이 0.0 또는 1.0에 가깝거나 가까워짐에 따라 (output) * (1 - output)의 값은 점점 작아진다.
예를 들어 output = 0.6이라면 (output) * (1 - output) = 0.24이지만 출력이 0.95이면 (output) * (1 - output) = 0.0475이다. 조정 요소가 점점 작아지면서 가중치 변화도 점점 작아지고 학습 진행이 멈출 수 있다.

그러나 ACE를 사용하면 (output) * (1 - output) 항이 사라진다. 따라서 가중치 변화는 점점 작아지거나 하지 않으므로 학습이 멈추거나 하지 않는다.
(위 경우는 노드 Activation을 softmax로 했을 경우이다.)



참고 문헌
[1] J. M. McCaffrey의 블로그