치즈의 AI 녹이기

Metric Learning 본문

인공지능 대학원생의 생활/구글링

Metric Learning

개발자 치즈 2021. 7. 8. 16:25

오늘은 Metric Learning에 대해 다뤄보도록 합니다. 

 

데이터 point 간 거리를 측정하기 위한 방법론들(Euclidean, Cosine..)이 존재하고 있지만, 

특정 task, 또는 데이터에 맞는 distance metric이 필요하다는 취지로 metric learning이 등장하였습니다. 

 

따라서 metric learning이란, 기계 학습을 이용하여 데이터로부터 task-specific한 distance metric을 자동으로 구성하는 방법입니다. 그 결과를 k-NN classification, clustering 등에서 활용할 수 있습니다. 

 

metric learning 문제는 두 가지 타입의 데이터 성격을 전제로 합니다.

  • Supervised learning: 모든 데이터가 잘 라벨링되어 있는 경우, anchor sample을 기준으로 positive sample은 서로 가깝게, negative sample은 서로 멀게 학습합니다. 
  • Weakly supervised learning: 라벨이 명시되어 있지 않고 쌍으로 제공되는 경우, positive pair들을 서로 가깝게, negative pair들을 서로 멀게 학습합니다. 

* anchor : 중심점이 되는 데이터

* positive sample : anchor을 기준으로 anchor과 같은 클래스에 속한 sample

* negative sample : anchor을 기준으로 anchor과 다른 클래스에 속한 sample

 

metric-learn package에서 대부분 'Mahalanobis distance' 알고리즘을 사용합니다. 

Mahalanobis distance

 

여기서 L은 (num_dims, n_features) shape으로, n_features는 데이터 feature를 의미합니다.

Lx = (num_dims, n_features) x (n_features, 1) = (num_dims, 1) 가 되면서

num_dims size의 임베딩 공간을 학습할 수 있습니다. 

 

 

 

참고 링크 : http://contrib.scikit-learn.org/metric-learn/introduction.html