우리의 모델은 child network weight w 뿐만이 아니라 architect parameter α 도 갖는다. 거기다가 우리는 이 α 에 대해서는 unrolled gradient 를 계산해야 하기 때문에 update 2 가 상당히 복잡해진다. 먼저 식을 정리해보자.
Unrolled gradient 를 계산하기 위해 알아야 할 것이 두 가지 있다.
다변수 함수의 미분. 정확히 말하면 다변수 함수의 미분은 보통 우리가 하던 partial derivative 로 하는거고, 여기서 하는 다변수 함수의 미분이란 합성함수의 미분에 가깝다. 각 변수들이 사실은 하나의 변수의 함수인 것.
Single variable:
ddtf(x(t))=∂f∂x∂x∂tMulti variable:
ddtf(x(t),y(t))=∂f∂x∂x∂t+∂f∂y∂y∂t임의의 함수 f(x) 가 x=a 에서 무한 번 미분 가능하다면,
f(x)=f(a)+f′(a)(x−a)+f″(a)2!(x−a)2+⋯=∞∑k=0f(k)(a)k!(x−a)k가 x=a 근처에서 성립한다.
테일러 급수로 근사한 식을 p(x) 라 한다면 (f(x)=p∞(x)), f′(a)=p′(a), f′‘(a)=p′‘(a), … 이 성립한다. 즉, 테일러 급수는 x=a 에서 동일한 미분계수를 갖는 함수로 근사하는 방법. 위키를 참고하면 실제로 근사가 진행되면서 (고차 미분이 더해지면서) 점점 정확해지는 것을 볼 수 있다.
위 식을 다변수 함수로 확장하면, n차 테일러 급수는:
T(n)[f,a](x)=n∑k=0(∂kxf)(a)k!(x−a)k라 쓸 수 있다.
위 식은 임성빈 박사님의 포스트에서 가져왔지만 n차 표기는 내 마음대로 덧붙였다. 수학적 convention 이 아니니 주의하자.
주의) 여기서 ∇f(x)=(∇f)(x) 다. 만약 ∇(f(x)) 의 표기가 필요할 경우 따로 표기. 혼동의 소지가 있으나 논문과 동일한 방식을 따름.
update 2 의 식:
Lval(w−ξ∇wLtrain(w,α),α)에서, w′=w−ξ∇wLtrain(w,α) 라 두고 α 에 대해 multivariable chain rule 을 통해 그라디언트를 계산하면:
∇α[Lval(w′,α)]=∇αLval(w′,α)−ξ∇2α,wLtrain(w,α)∇w′Lval(w′,α)가 된다. 여기서 이 뒤의 헤시안 항을 finite difference 로 근사할 수 있는데:
∇2α,wLtrain(w,α)∇w′Lval(w′,α)≈∇αLtrain(w+,α)−∇αLtrain(w−,α)2ϵwherew+=w+ϵ∇w′Lval(w′,α)w−=w−ϵ∇w′Lval(w′,α)ϵ=small scalar가 되어, 연산 복잡도가 O(|α||w|) 에서 O(|α|+|w|) 로 줄어든다. 또한 여기서 추가로 등장하는 두 하이퍼파라메터 ξ 와 ϵ 에 대한 실험적인 설정값도 논문에서 (경험적으로) 제공한다. virtual gradient step 의 learning rate 인 ξ 는 w 의 learning rate 와 동일하게 사용하며, 엡실론의 경우 ϵ=0.01/||∇w′Lval(w′,α)|| 를 사용한다.
그렇다면, 최종적으로 unrolled gradient (virtual step gradient) 식을 정리해 보자.
∇α[Lval(w−ξ∇wLtrain(w,α),α)]=∇αLval(w′,α)−ξ∇2α,wLtrain(w,α)∇w′Lval(w′,α)≈∇αLval(w′,α)−∇αLtrain(w+,α)−∇αLtrain(w−,α)2ϵ이제 전부 다 넣어서 풀면 다음과 같다:
∇αLval(w′,α)−∇αLtrain(w+ϵ∇w′Lval(w′,α),α)−∇αLtrain(w−ϵ∇w′Lval(w′,α),α)2⋅0.01/||∇w′Lval(w′,α)||임성빈 박사님의 포스트를 상당 부분 참조하여 작성
식 (7) 로부터 시작하자:
∇2α,wLtrain(w,α)∇w′Lval(w′,α)≈∇αLtrain(w+,α)−∇αLtrain(w−,α)2ϵ여기서 오른쪽의 분자 항을 테일러 시리즈로 근사할 수 있다. 그러면 이 때
f(w)=∇αLtrain(w,α)라 하면,
f(w)≈T(2)[f,w](x)가 x=w 근처에서 성립하고,
T(2)[f,w](x)=f(w)+∇wf(w)(w−x)+12∇2wf(w)(x−w)2=∇αLtrain(w,α)+∇w∇αLtrain(w,α)(x−w)+12∇2w∇αLtrain(w,α)(x−w)2가 된다. 여기서 이 Taylor series 함수에 w+ 와 w− 를 넣어 빼주면, |w+−w|=|w−−w| 이므로 첫번째와 세번째 항이 사라진다. 그러면:
T(2)[∇αLtrain,w](w+)−T(2)[∇αLtrain,w](w−)=∇w∇αLtrain(w,α)(w+−w−)=2ϵ∇2α,wLtrain(w,α)∇w′Lval(w′,α)이므로 식 (7) 을 얻을 수 있다. 마지막 전개는 w+−w−=2ϵ∇w′Lval(w′,α) 이기 때문이다.
세번째 항 끼리 뺄 때 중간에 Hessian 이 들어가서 다소 복잡하지만 잘 풀어서 빼 보면 사라지는 것을 확인할 수 있다.