ML || DL/이론

[Coursera] DLS_C3W2: Transfer Learning / Multi-task Learning / End-to-End Learning

junmukbap98 2023. 9. 26. 15:06

1. Transfer Learning

: 우리가 원하는 task의 (상대적으로 적은) dataset을 활용해서, 큰 dataset으로 pre-trained 된 NN를 retrain 해서 target task를 수행하는 것

 

만약, 우리가 radiology diagnosis task를 수행하려고 한다고 하자. 이때 관련 데이터셋이 100장 밖에 없다면 어떻게 효과적으로 diagnosis task를 수행하는 model을 만들 수 있을까?

이런 상황에서 생각해 볼 수 있는 것이 Transfer learning 혹은 fine-tuning이다. 

 

먼저, 매우 큰 image recognition dataset으로 모델을 학습한 다음, radiology diagnosis dataset을 활용해서

  • 앞의 layer들은 freeze 하고 마지막 몇 개의 layer의 weight만을 retrain 한다. (Transfer learning)
  • 혹은 pretrained networkd의 모든 layer들을 retrain 한다. (Fine-tuning)

(Transfer learing의 한 종류가 fine-tuning이라고 생각하면 된다.)

 

즉, image recognition에서 얻은 knowledge를 radiology diagnosis에 응용하거나 transfer 하는 것이다. 

이것이 효과적인 이유는 이미지를 인식하는데 edge나 curve 등의 low-level feature을 detect 하는 능력이 radiology diagnosis task에 도움이 될 수 있기 때문이다. 

 

이러한 transfer learning이 잘 동작하기 위해서는 세 가지 조건이 필요하다. (Task A로부터 Task B에 knowlege를 transfer 하는 경우에)

  • Task A와 B는 same input x를 가져야 한다. 즉 B의 input이 이미지라면, A의 input도 이미지여야 한다. 
  • Task B보다 Task A에 대해 더 많은 data를 가져야 한다. 
  • A로부터 얻은 low-level feature가 B를 학습하는데 도움이 되어야 한다. 즉, 완전히 상관없는 두 task에서 knowlege를 전달하는 것은 무의미할 수 있다. 어느 정도 비슷한, 연관성이 있는 task들끼리 transfer를 진행해야 한다. 

 


 

2. Multi-task Learning

: 한 개의 NN가 여러 task를 처리할 수 있도록 하는 것 (대표적으로 object recognition이 있다.)

 

예를 들어, 우리가 하나의 사진 내에서 보행자, 자동차, 표지판, 신호등을 detection 하는 task를 수행한다고 해보자. 

그러면, 우리의 NN은 총 4개의 object에 대해 검출이 됐는지 (1) 아닌지 (0)를 예측해야 한다. 

즉, $\hat{y}^{(i)}=(4,1)$의 크기를 갖는다. 

 

이러한 네트워크를 학습하기 위해서 우리는 다음과 같은 loss를 사용할 수 있다. 

$$ \frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{4} L(\hat{y}_{j}^{(i)}, y_{j}^{(i)})$$

여기서 $L$은 보편적인 logistic loss이다. 

 

이러한 multi-task learning의 loss와 softmax regression과의 차이점은 하나의 이미지가 multiple labels를 가질 수 있다는 점이다. 

즉, 한 개의 이미지를 보고 4개의 문제를 푸는 하나의 NN를 학습하는 것이다. 

 

이때, NN의 earlier feature의 일부가 다양한 objects사이의 공유될 수 있는 경우 4개의 NN를 만드는 것보다 1개의 NN이 multi-task learning을 하도록 하는 것이 성능이 더 좋다. 

 

Multi-task learning이 잘 동작하기 위한 조건

  • Lower-level feature를 공유함으로써 이점을 얻을 수 있는 몇 개의 작업에 대해서 Training 해야 한다.
  • (Usually) 각 task에 대해 갖고 있는 data의 양이 비슷해야 한다.
  • 모든 tasks에 대해 잘 동작할 수 있도록 충분히 큰 NN를 학습해야 한다. 

 


 

3. End-to-End Learning

: End-to-End learning은 data processing system, learning system 등의 multiple satege들을 하나의 NN로 변환하는 것을 의미한다. 즉 model이 input X를 가지고 output Y를 도출해 내는 mapping function을 학습하는 것이다. 

 

Speech recognition 예시를 들어보면, 아래와 같다.

  • 전통적인 방법: audio (x) --> hand designed feature --> phonemes --> words ---> transcript (y)
  • End-to-End: audio (x) ---------------------------------------------------------> transcript (y)

이때, end-to-end learning이 잘 동작하기 위해서는 많은 양의 data가 필요하다. 가령, 전통적인 approach에서는 3,000h 정도의 데이터만 있어도 잘 동작하는데 end-to-end의 경우 10,000h~100,000h 정도의 데이터가 필요할 수 있다.

 

하지만, 이런 end-to end learning이 항상 효과적인 것은 아니다.  보통 multiple step으로 나누어서 문제를 해결하는 것이 더 효과적일 때도 있다. 예를 들어 Face recognition의 경우, 아래의 두 가지 접근법을 취해볼 수 있다. 

  • Two-step    :     Image (x) ----> Face detection ----> Position of person's face (and crop) ----> Identity (y)
  • End-to-End:     Image (x) ---------------------------------------------------------------------> Identity (y)

하지만 end-to-end 보다 각각의 문제를 sub-task로 나누었을 때 문제를 더 simple하게 해결할 수 있고, 

end-to-end를 위한 dataset보다 각 sub task들에 대한 dataset이 훨씬 수집하기 쉽다. 따라서 이런 two-step approach가 현재로써는 훨씬 더 효과적일 수 있다. 

 

우리가 end-to-end learning approach를 취할지 말지 결정하기 전에, 장단점을 살펴보자.

 

<장점>

  • Let the data speak: 데이터를 순수하게 반영할 수 있게 해준다. 즉, 인간의 선입견을 강제로 반영하기보다는 data 내의 어떠한 통계적 특성을 더 잘 capture 할 수 있다. 가령, speech recognition task에서 기계는 phonemes보다 더 효과적인 표현을 learning 해서 더 나은 성능을 보여줄 수도 있는 것이다. 
  • Less hand-designing of components needed: 즉 design work flow를 단순화 할 수 있다.

<단점>

  • May need large amout of data
  • Excludes potentially useful hand-designed components: 만약 데이터셋이 적을 경우, 알고리즘이 데이터셋에서 얻을 수 있는 인사이트가 부족할 수 있다. 이때 잘 설계된 hand-designed system을 사용하면 유용한 knowledge를 알고리즘에게 전달할 수 있는데, 이러한 것들이 모두 배제된다. 

이렇든 end-to-end learning은 장단점이 존재하고, 만병통치약처럼 쓸 수 있는 접근법은 아니다.

따라서, 우리는 end-to-end learning을 적용하려고 할때, 

"X를 Y에 매핑하는 데 필요한 (복잡한) 함수를 학습하기에 충분한 데이터가 있는지?"를 항상 확인 해야 한다.