Processing math: 100%

Khanrc

AutoML (3) - DARTS: math

  1. Introduction of AutoML and NAS
  2. DARTS
  3. DARTS: mathmatics
  4. DARTS: tutorial
  5. DARTS: multi-gpu extension

Mathmatics in DARTS

Algorithm

우리의 모델은 child network weight w 뿐만이 아니라 architect parameter α 도 갖는다. 거기다가 우리는 이 α 에 대해서는 unrolled gradient 를 계산해야 하기 때문에 update 2 가 상당히 복잡해진다. 먼저 식을 정리해보자.

Preliminary

Unrolled gradient 를 계산하기 위해 알아야 할 것이 두 가지 있다.

The Multivariable Chain Rule

다변수 함수의 미분. 정확히 말하면 다변수 함수의 미분은 보통 우리가 하던 partial derivative 로 하는거고, 여기서 하는 다변수 함수의 미분이란 합성함수의 미분에 가깝다. 각 변수들이 사실은 하나의 변수의 함수인 것.

Single variable:

ddtf(x(t))=fxxt

Multi variable:

ddtf(x(t),y(t))=fxxt+fyyt

Taylor Series

임의의 함수 f(x)x=a 에서 무한 번 미분 가능하다면,

f(x)=f(a)+f(a)(xa)+f(a)2!(xa)2+=k=0f(k)(a)k!(xa)k

x=a 근처에서 성립한다.

Intuition

테일러 급수로 근사한 식을 p(x) 라 한다면 (f(x)=p(x)), f(a)=p(a), f(a)=p(a), … 이 성립한다. 즉, 테일러 급수는 x=a 에서 동일한 미분계수를 갖는 함수로 근사하는 방법. 위키를 참고하면 실제로 근사가 진행되면서 (고차 미분이 더해지면서) 점점 정확해지는 것을 볼 수 있다.

Multivariable function

위 식을 다변수 함수로 확장하면, n차 테일러 급수는:

T(n)[f,a](x)=nk=0(kxf)(a)k!(xa)k

라 쓸 수 있다.

위 식은 임성빈 박사님의 포스트에서 가져왔지만 n차 표기는 내 마음대로 덧붙였다. 수학적 convention 이 아니니 주의하자.

Unrolled gradient

주의) 여기서 f(x)=(f)(x) 다. 만약 (f(x)) 의 표기가 필요할 경우 따로 표기. 혼동의 소지가 있으나 논문과 동일한 방식을 따름.

update 2 의 식:

Lval(wξwLtrain(w,α),α)

에서, w=wξwLtrain(w,α) 라 두고 α 에 대해 multivariable chain rule 을 통해 그라디언트를 계산하면:

α[Lval(w,α)]=αLval(w,α)ξ2α,wLtrain(w,α)wLval(w,α)

가 된다. 여기서 이 뒤의 헤시안 항을 finite difference 로 근사할 수 있는데:

2α,wLtrain(w,α)wLval(w,α)αLtrain(w+,α)αLtrain(w,α)2ϵwherew+=w+ϵwLval(w,α)w=wϵwLval(w,α)ϵ=small scalar

가 되어, 연산 복잡도가 O(|α||w|) 에서 O(|α|+|w|) 로 줄어든다. 또한 여기서 추가로 등장하는 두 하이퍼파라메터 ξϵ 에 대한 실험적인 설정값도 논문에서 (경험적으로) 제공한다. virtual gradient step 의 learning rate 인 ξw 의 learning rate 와 동일하게 사용하며, 엡실론의 경우 ϵ=0.01/||wLval(w,α)|| 를 사용한다.

그렇다면, 최종적으로 unrolled gradient (virtual step gradient) 식을 정리해 보자.

α[Lval(wξwLtrain(w,α),α)]=αLval(w,α)ξ2α,wLtrain(w,α)wLval(w,α)αLval(w,α)αLtrain(w+,α)αLtrain(w,α)2ϵ

이제 전부 다 넣어서 풀면 다음과 같다:

αLval(w,α)αLtrain(w+ϵwLval(w,α),α)αLtrain(wϵwLval(w,α),α)20.01/||wLval(w,α)||

Hessian term - eq (7)

임성빈 박사님의 포스트를 상당 부분 참조하여 작성

식 (7) 로부터 시작하자:

2α,wLtrain(w,α)wLval(w,α)αLtrain(w+,α)αLtrain(w,α)2ϵ

여기서 오른쪽의 분자 항을 테일러 시리즈로 근사할 수 있다. 그러면 이 때

f(w)=αLtrain(w,α)

라 하면,

f(w)T(2)[f,w](x)

가 x=w 근처에서 성립하고,

T(2)[f,w](x)=f(w)+wf(w)(wx)+122wf(w)(xw)2=αLtrain(w,α)+wαLtrain(w,α)(xw)+122wαLtrain(w,α)(xw)2

가 된다. 여기서 이 Taylor series 함수에 w+w 를 넣어 빼주면, |w+w|=|ww| 이므로 첫번째와 세번째 항이 사라진다. 그러면:

T(2)[αLtrain,w](w+)T(2)[αLtrain,w](w)=wαLtrain(w,α)(w+w)=2ϵ2α,wLtrain(w,α)wLval(w,α)

이므로 식 (7) 을 얻을 수 있다. 마지막 전개는 w+w=2ϵwLval(w,α) 이기 때문이다.

세번째 항 끼리 뺄 때 중간에 Hessian 이 들어가서 다소 복잡하지만 잘 풀어서 빼 보면 사라지는 것을 확인할 수 있다.

References