AI/머신러닝

[ML] 다중 클래스 분류(Multi-Class Classification) 정리

daeunnniii 2023. 5. 12. 23:50
728x90
반응형

이전 로지스틱 회귀 게시물에서는 독립 변수 $ x $가 1개인 이진 분류(Binary Classification)를 다루었다.

이번에는 독립 변수 $ x $가 2개 이상인 다중 클래스 분류(Multi-class Classification)와 소프트맥스 회귀(Softmax Regression)에 대해 정리할 것이다.

https://daeunnniii.tistory.com/194

 

[ML] 로지스틱 회귀(Logistic Regression) 쉽게 이해하기 & Pytorch 구현

로지스틱 회귀(Logistic Regression) 로지스틱 회귀를 알아보기 위해 선형회귀와 마찬가지로 독립 변수 $ x $가 1개인 이진 분류(Binary Classification)에 대해 정리하고 Pytorch로 구현해본 뒤 독립 변수 $ x $

daeunnniii.tistory.com

 

다중 클래스 분류(Multi-Class Classification)

Binary Classification의 경우에는 스팸 메일인가/아닌가와 같이 보통 True (0) / False (1)로 답이 나뉘는 문제였다. Multiclass Classification의 경우에는 정답이 3가지 이상이고 k개의 클래스에 해당하는 각 확률값을 결과로 낸다. 예를 들어 친구가 보낸 메일인지, 직장으로부터 온 메일인지, 결제 관련 메일인지, 기타 취미 활동에서 온 메일인지를 분류하고자 하는 경우 $ y $ 결과값이 1, 2, 3, 4의 4가지로 나타낼 수 있고 y=1은 0.1, y=2은 0.6, y=3은 0.2, y=4는 0.1과 같이 각 클래스의 확률값이 결과이다. 그리고 당연하게도 확률값의 총 합은 1이다. 

 

소프트맥스 회귀(Softmax Regression)

로지스틱 회귀의 활성화 함수는 시그모이드 함수(sigmoide function)이고, 이진 분류(Binary Classification)을 위한 것이었다면, 소프트맥스 회귀의 활성화 함수는 소프트맥스 함수(Softmax function)이고, 다중 분류(Multiclass Classification)을 위한 것이다.

 

1. 소프트맥스 함수(Softmax Regression)

1) 시그모이드 함수의 문제점

  • 값이 조금만 커지거나 작아져도 0 또는 1에 수렴한다. (saturation)
  • 시그모이드 함수로 다중분류를 진행할 경우 각각의 node에 대해 상대적인 평가가 불가능하다. (결과가 두가지로 나눠지므로)

소프트맥스 함수는 시그모이드 함수의 문제점을 보완하여 saturation 문제가 발생하지 않으며, 각 node에 대한 상대적인 평가가 가능하다. 

 

2) 소프트맥스 함수 원리

k개의 클래스로 분류하기 위해 다중 분류를 한다고 가정할 때,

다중 분류의 첫 번째 단계로, 가중치 연산을 진행한다.

가중치 연산을 통해 행렬곱을 진행하여 $ \hat{Y} $를 얻었다고 하자. 이 $ \hat{Y} $은 k차원의 벡터이며 k차원의 벡터 $ [z_1, z_2, ... , z_k] $ 에서 i번째 원소를 $ z_i $라고 할 때, 클래스 i에 속할 확률을 의미한다. 예를 들어 $ z = [0.3, 0.8, 0.2] $일 때 첫번째 클래스일 확률은 30%, 두번째 클래스일 확률은 80%, 세번째 클래스일 확률은 20%임을 나타낸다.

 

그 다음 단계로, 소프트맥스 함수는 예측값에 해당하는 k차원의 벡터 $ z $를 총합이 1인 확률 분포로 바꾸는 역할을 한다.

i번째 클래스가 정답일 확률을 $ p_i $로 나타낸다고 하자. 이때 소프트맥스 함수는 $ p_i $를 다음과 같이 정의한다.

만약 k=3일 경우 3차원의 벡터 $ z = [z_1, z_2, z_3] $의 입력을 받으면 소프트맥스 함수는 다음과 같은 출력을 리턴한다. $ p1, p2, p3 $는 각각 클래스1이 정답일 확률, 클래스2가 정답일 확률, 클래스3이 정답일 확률을 나타낸다.

 

$ softmax(z) = [\frac{e^{z_1}}{\sum_{j=1}^{3}e^{z_j}} \frac{e^{z_2}}{\sum_{j=1}^{3}e^{z_j}} \frac{e^{z_3}}{\sum_{j=1}^{3}e^{z_j}}] = [p_1, p_2, p_3] = [p_{virginica}, p_{setosa}, p_{versicolor}] = \hat{y} = $ 예측값

 

즉, 분류하고하자 하는 클래스가 k개일 때, k차원의 벡터를 입력받아 모든 벡터의 원소의 값을 0과 1 사이의 값으로 변경하여 다시 k차원의 벡터를 리턴한다고 정리할 수 있다.

 

2. 비용함수(Cost function): Cross Entropy

소프트맥스 회귀에서는 비용 함수로 크로스 엔트로피 함수(Cross Entropy Function)를 사용한다.

Entropy란 불확실성(uncertainty) 에 대한 척도이다. 크로스 엔트로피 함수 $ cost(W) $는 다음과 같다.

$ y $는 실제값이고 $ y_j $는 실제값의 원-핫 벡터의 j번째 인덱스에 해당한다. $ k $는 클래스의 개수, $ p_j $는 샘플 데이터가 j번째 클래스일 확률을 나타낸다. $ p_j $는 $ \hat{y_j} $로 표현하기도 한다.

이해하기 쉬운 예시로, 동전 던지기와 주사위 던지기 두 상황을 고려해볼 것이다. 동전을 던졌을 때 앞/뒷면이 나올 확률을 모두 1/2이고, 주사위를 던졌을 때 각 6면이 나올 확률을 모두 1/6이라고 하자. 두 상황에서 불확실성이 높은 것(즉, 어떤 데이터가 나올지 예측하기 어려운 것)은 주사위라는 것을 알 수 있다. 이를 수식으로 계산하면 다음과 같다.

  • $ H(x) = -(\frac{1}{2}\log{\frac{1}{2}} + \frac{1}{2}\log{\frac{1}{2}}) = \log{2} \approx 0.693 $
  • $ H(x) = -(\frac{1}{6}\log{\frac{1}{6}} + \frac{1}{6}\log{\frac{1}{6}} + \frac{1}{6}\log{\frac{1}{6}} + \frac{1}{6}\log{\frac{1}{6}} + \frac{1}{6}\log{\frac{1}{6}} + \frac{1}{6}\log{\frac{1}{6}}) = \log{6} \approx 1.791 $

n개의 전체 데이터에 대한 평균을 구하는 최종 비용함수는 다음과 같다.

 

위 식은 앞서 로지스틱 회귀에서 본 Binary Cross Entropy와 본질적으로는 동일한 함수식이다. 실제값 $ y=[0, 0, 1, 0, ..., 0] $일 때 $ a $가 실제값 원-핫 벡터에서 원소가 1인 3번째 인덱스이고 $ p_a = 1 $이라고 한다면, $ \hat{y} $이 $ y $를 정확하게 예측한 경우가 된다. 따라서 위 식에 대입하면 $ -(0 + 0 + 1\log(1) + ... + 0) = 0 $이 되기 때문에 결과적으로 정확하게 예측한 경우 크로스 엔트로피 함수의 값은 0이 된다.

 

다중 분류 예제

붓꽃 품종 분류 문제를 예제로 다중 분류를 진행해볼 것이다.

이는 4개의 특성(feature) 꽃받침 길이(SepalLength), 꽃받침 넓이(SepalWidth), 꽃잎 길이(PetalLength), 꽃잎 넓이(PetalWidth)로부터 setosa, versicolor, virginica 3가지 붓꽃 품종 중 어떤 품종인지를 예측하는 문제로 전형적인 다중 클래스 분류 문제이다.

 

다음과 같이 샘플 데이터가 5개가 있다고 하자.

하나의 샘플 데이터는 4개의 독립변수 $ x $를 가지며 이는 모델이 4차원 벡터를 입력으로 받음을 의미한다. 그런데 소프트맥스 함수의 입력으로 사용되는 벡터는 분류하고자하는 클래스의 개수가 되어야하므로 가중치 연산을 통해 붓꽃 품종의 개수 3인 3차원 벡터로 변환되어야한다.

이미지 출처: 위키독스(Wikidocs)

샘플 데이터 벡터를 소프트맥스 함수의 입력으로 축소하는 방법은 간단한다.

소프트맥스 함수의 입력벡터 z의 차원수만큼 결과값이 나오도록 가중치 곱을 진행하면 된다. 위 그림에서 화살표는 총 4 x 3 = 12개이며 전부 다른 가중치를 가지고, 학습 과정에서 점차적으로 오차를 최소화하는 가중치로 값이 변경된다.

 

전체 샘플 개수가 5개, 특성이 4개이므로 아래와 같이 5 x 4 행렬 $ X $를 정의할 수 있다.

그리고 분류하고자 하는 클래스는 3개이므로 가설의 예측값으로 얻는 행렬 $ \hat{Y} $의 열의 개수는 3개이어야 하고, 행의 크기는 $ X $와 동일해야한다. $ \hat{Y} $은 5 x 3 행렬이 될 것이고, $ \hat{Y} $은 5 x 4 입력 행렬 $ X $와 가중치 행렬 $ W $의 곱으로 얻어지는 행렬이므로 가중치 행렬 $ W $의 크기는 4 x 3의 크기를 가진 행렬임을 알 수 있다. 결과적으로 아래와 같이 나타낼 수 있다.

 

가설: $ H(X) = \hat{Y} = softmax(WX + B) $

이미지 출처: 위키독스(Wikidocs)

 

변환한 3차원의 벡터를 소프트맥스 함수의 입력으로 넣으면 이에 대한 출력은 분류하고자하는 클래스의 개수인 3차원을 가지는 벡터로, 각 원소는 0과 1사이의 값을 가진다. 이는 각각 특정 클래스가 정답일 확률을 나타낸다. 첫번째 원소인 $ p_1 $은 붓꽃 품종이 virginica가 정답일 확률, $ p_2 $는 setosa가 정답일 확률, $ p_3 $는 versicolor가 정답일 확률을 나타낸다.

따라서 각 클래스의 확률 중에서 가장 높은 확률을 갖는 클래스가 출력되는 것이 소프트맥스 회귀의 작동 방식이다.

 

그렇다면 이제 예측값과 실제값을 비교해서 오차를 구해 가중치를 업데이트하는 과정이 진행될 것이라고 예상할 수 있다. 소프트맥스 회귀에서는 출력의 종류를 분류하고자 하는 클래스의 개수만큼 늘리기 위해 one-hot encoding 기법을 사용한다. one-hot encoding은 출력값을 하나의 스칼라 값으로 내보내지 않고, 여러 값을 담은 벡터로 내보내는 기법이다. 따라서 실제값을 원-핫 벡터로 표현한다.

  • virginica의 원-핫 벡터: [1, 0, 0]
  • setosa의 원-핫 벡터: [0, 1, 0]
  • versicolor의 원-핫 벡터: [0, 0, 1]

 

현재 풀고 있는 샘플 데이터의 실제값이 setosa이고 원-핫 벡터는 [0, 1, 0]라고 하자. 소프트맥스 함수의 출력 벡터는 [0.26, 0.71, 0.04]이므로 오차 함수 Cross Entropy 함수에 적용하면 $ -(0 + 1\log{(0.71)} + 0) \approx 0.342 $이다.

 

이미지 출처: 위키독스(Wikidocs)

이후 역전파를 수행하여 비용함수에 대해 가중치 업데이트를 진행한다.

$ W^+ = W - \alpha \frac{\partial C}{\partial W} $ 

 

  • 역전파 과정 참고:

https://daeunnniii.tistory.com/189

 

역전파와 경사하강법 쉽게 이해하기

역전파(BackPropagation)는 한마디로 신경망 모델에서 오차를 이용하여 가중치를 업데이트하는 방법이다. 1. 인공 신경망의 이해 예제로 사용할 인공 신경망은 다음과 같이 입력층, 은닉층, 출력층 3

daeunnniii.tistory.com

 

이진 분류 & 다중 분류 정리

  이진 분류(Binary Classification) 다중 분류(Multiclass Classification)
출력값($ z = W^TX + b $) 상수 길이가 k인 배열을 출력($ z_1, z_2, ... , z_k $)
활성화 함수 출력값($ \hat{y} $) sigmoid ($ \sigma(z)=\frac{1}{1+e^{-z}} $) softmax ($ \frac{e^{z_{1}}}{\sum_{i=1}^{k}e^{z_{i}}}, \cdots, \frac{e^{z_{k}}}{\sum_{i=1}^{k}e^{z_{i}}} $)
비용 함수 Binary Cross Entropy
$ -\frac{1}{n}\sum_{i=1}^{n}(y^{(i)}\log\hat{y}^{(i)}+(1-y^{(i)})\log(1-\hat{y}^{(i)})) $
Cross Entropy
$ -\frac{1}{n}\sum_{i=1}^{n}y^{(i)}\cdot \log{\hat{y}^{(i)}} $

 

 

참고:

- 위키독스 Pytorch로 시작하는 딥러닝 입문

- 위키독스 딥 러닝을 이용한 자연어 처리 입문

728x90
반응형