
Neural Tangent Kernel

From the last post, we know that a wide neural network with reasonable initialization behaves like a Gaussian process at initialization, i.e., before training. What can we say about a wide NN after training?

Gradient Descent with Small LR

Suppose we have dataset $\left\{\left(\boldsymbol{x}_{i}, y_{i}\right)\right\}_{i=1}^{n} \subset \mathbb{R}^{d} \times \mathbb{R}$ and loss function $\ell(\boldsymbol{\theta})=\frac{1}{2} \sum_{i=1}^{n}\left(f\left(\boldsymbol{\theta}, \boldsymbol{x}_{i}\right)-y_{i}\right)^{2}$. Then, gradient descent is defined as
$$\boldsymbol\theta(t+1) = \boldsymbol\theta(t) - \eta\nabla_{\boldsymbol\theta}\ell(\boldsymbol\theta(t))$$ We can rewrite this as
$${\boldsymbol\theta(t+1) - \boldsymbol\theta(t) \over \eta}=- \nabla_{\boldsymbol\theta}\ell(\boldsymbol\theta(t))$$ As $\eta\to 0$, this becomes the gradient flow
$${d\boldsymbol\theta(t)\over dt}=-\nabla_{\boldsymbol\theta}\ell(\boldsymbol\theta(t))$$ Let $$
\boldsymbol{u}(t)=\left[f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{i}\right)\right]_{i \in \{1,2,\ldots,n\}} \in \mathbb{R}^{n}
$$ be the network output, then
$${d\boldsymbol u\over dt}=-\boldsymbol H(t)(\boldsymbol u(t) -\boldsymbol y)$$ where
$$\boldsymbol H_{ij}(t) =\left\langle {\partial f(\boldsymbol \theta(t),\boldsymbol x_i)\over \partial \boldsymbol\theta}, {\partial f(\boldsymbol \theta(t),\boldsymbol x_j)\over \partial \boldsymbol\theta} \right\rangle$$ Proof

We have
$$
\begin{align}
{du_i\over dt}={d f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{i}\right)\over dt}&=\left\langle {\partial f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{i}\right)\over \partial \boldsymbol\theta}, {d\boldsymbol\theta(t)\over dt} \right\rangle \\
&=\left\langle {\partial f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{i}\right)\over \partial \boldsymbol\theta}, -\nabla_{\boldsymbol\theta}\ell(\boldsymbol\theta(t)) \right\rangle \\
&=\left\langle {\partial f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{i}\right)\over \partial \boldsymbol\theta}, -\sum_{j=1}^{n}\left(f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{j}\right)-y_{j}\right){\partial f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{j}\right)\over \partial \boldsymbol\theta} \right\rangle \\
&=-\sum_{j=1}^{n}\left(f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{j}\right)-y_{j}\right)\left\langle {\partial f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{i}\right)\over \partial \boldsymbol\theta}, {\partial f\left(\boldsymbol{\theta}(t), \boldsymbol{x}_{j}\right)\over \partial \boldsymbol\theta} \right\rangle \\
&=-\sum_{j=1}^{n}\left(u_j(t)-y_{j}\right)\boldsymbol H_{ij}(t)
\end{align}
$$ $\blacksquare$
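
To make $\boldsymbol H(t)$ concrete, here is a minimal NumPy sketch (my own illustration, not part of the original derivation; the one-hidden-layer ReLU architecture and all sizes are arbitrary choices) that computes the empirical kernel matrix from per-example parameter gradients.

```python
# Minimal sketch: the empirical kernel H(t) for a one-hidden-layer ReLU network
# f(theta, x) = W2 @ (sqrt(2/m) * relu(W1 @ x)), in NTK parameterization.
# H_ij(t) = <df/dtheta(x_i), df/dtheta(x_j)>, with gradients computed by hand.
import numpy as np

rng = np.random.default_rng(0)
d, m, n = 3, 512, 5                       # input dim, width, number of samples
X = rng.normal(size=(n, d)) / np.sqrt(d)  # toy inputs
W1 = rng.normal(size=(m, d))              # N(0,1) init
W2 = rng.normal(size=(1, m))

def per_example_grad(x):
    """Return df/dtheta(x) flattened into one vector."""
    pre = W1 @ x                                   # f^(1)(x)
    act = np.sqrt(2.0 / m) * np.maximum(pre, 0.0)  # g^(1)(x), c_sigma = 2 for ReLU
    dW2 = act                                      # df/dW2 since f = W2 @ act
    b1 = np.sqrt(2.0 / m) * (pre > 0) * W2[0]      # b^(1)(x)
    dW1 = np.outer(b1, x)                          # df/dW1 = b^(1)(x) x^T
    return np.concatenate([dW1.ravel(), dW2.ravel()])

grads = np.stack([per_example_grad(x) for x in X])  # (n, #params)
H = grads @ grads.T                                 # empirical kernel matrix H(t)
print(np.round(H, 3))
```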

When the neural network is infinitely wide, $\boldsymbol H(t)$ converges to a deterministic kernel matrix $\boldsymbol H^*$ that stays constant throughout training; this limiting kernel is called the Neural Tangent Kernel (NTK). The dynamics become
$${d\boldsymbol u\over dt}=-\boldsymbol H^*(\boldsymbol u(t) -\boldsymbol y)$$

Neural Tangent Kernel

We will first define a neural network. Suppose the input data is $\boldsymbol x\in \mathbb R^d$, and set $\boldsymbol{g}^{(0)}(\boldsymbol x)=\boldsymbol x$ and $d_0=d$ for convenience. We define an $L$-hidden-layer fully connected neural network as

$$
\begin{align}
\boldsymbol{f}^{(h)}(\boldsymbol{x}) &=\boldsymbol{W}^{(h)} \boldsymbol{g}^{(h-1)}(\boldsymbol{x}) \in \mathbb{R}^{d_{h}} \\
\boldsymbol{g}^{(h)}(\boldsymbol{x}) &=\sqrt{\frac{c_{\sigma}}{d_{h}}} \sigma\left(\boldsymbol{f}^{(h)}(\boldsymbol{x})\right) \in \mathbb{R}^{d_{h}},\quad h=1,2,\ldots,L
\end{align}
$$

where

$\boldsymbol{W}^{(h)}\in \mathbb R^{d_h\times d_{h-1}}$ is the weight matrix of the $h$-th layer, $\sigma:\mathbb R\to \mathbb R$ is an activation function applied elementwise, and $c_{\sigma} = \left( \mathbb E_{z\sim \mathcal N(0,1)}\left[\sigma(z)^2\right] \right)^{-1}$. The last layer is

$$
f(\boldsymbol{\theta}, \boldsymbol{x}) =f^{(L+1)}(\boldsymbol{x})=\boldsymbol{W}^{(L+1)} \cdot \sqrt{\frac{c_{\sigma}}{d_{L}}} \sigma\left(\boldsymbol{W}^{(L)} \cdot \sqrt{\frac{c_{\sigma}}{d_{L-1}}} \sigma\left(\boldsymbol{W}^{(L-1)} \cdots \sqrt{\frac{c_{\sigma}}{d_{1}}} \sigma\left(\boldsymbol{W}^{(1)} \boldsymbol{x}\right)\right)\right)
$$

where $\boldsymbol{W}^{(L+1)}\in \mathbb R^{1\times d_{L}}$ is the weight of the final layer and $\boldsymbol\theta = \left(\boldsymbol{W}^{(1)},\ldots,\boldsymbol{W}^{(L+1)}\right)$ collects all the parameters. All weights are initialized as i.i.d. $\mathcal N(0,1)$. The scaling factor $\sqrt{{c_{\sigma}}/{d_{h}}}$ ensures that the norm of $\boldsymbol{g}^{(h)}(\boldsymbol{x})$ is approximately preserved across layers at initialization.
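
As a quick sanity check on this parameterization, the following sketch (ReLU and the widths are my own choices for illustration) implements the forward recursion with the $\sqrt{c_\sigma/d_h}$ scaling; at initialization the hidden-layer norms stay close to $\|\boldsymbol x\|$.

```python
# Sketch: forward pass of the L-hidden-layer network in NTK parameterization,
# g^(h) = sqrt(c_sigma / d_h) * sigma(f^(h)), with fresh N(0,1) weights.
import numpy as np

rng = np.random.default_rng(0)
sigma = lambda z: np.maximum(z, 0.0)   # ReLU; c_sigma = (E[sigma(z)^2])^{-1} = 2
c_sigma = 2.0

def forward(x, widths):
    """Return the scalar output f(theta, x) for randomly initialized weights."""
    g = x
    for d_h in widths:                          # hidden layers h = 1, ..., L
        W = rng.normal(size=(d_h, g.shape[0]))  # W^(h)
        f = W @ g                               # f^(h)(x)
        g = np.sqrt(c_sigma / d_h) * sigma(f)   # g^(h)(x)
        # at init, ||g^(h)||^2 stays close to ||x||^2 thanks to the scaling
    W_out = rng.normal(size=(1, g.shape[0]))    # W^(L+1)
    return (W_out @ g).item()

x = rng.normal(size=16); x /= np.linalg.norm(x)
print(forward(x, widths=[1024, 1024, 1024]))
```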

From the last post, we know that in the infinite-width limit the neural network defines a Gaussian process. Specifically, each coordinate of the pre-activation $\boldsymbol f^{(h)}(\boldsymbol x)$ can be seen as a centered Gaussian process with covariance $\Sigma^{(h-1)}$ defined as
$$
\begin{align}
\Sigma^{(0)}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right) &=\boldsymbol{x}^{\top} \boldsymbol{x}^{\prime} \\
\boldsymbol{\Lambda}^{(h)}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right) &=\left(\begin{array}{cc}
\Sigma^{(h-1)}(\boldsymbol{x}, \boldsymbol{x}) & \Sigma^{(h-1)}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right) \\
\Sigma^{(h-1)}\left(\boldsymbol{x}^{\prime}, \boldsymbol{x}\right) & \Sigma^{(h-1)}\left(\boldsymbol{x}^{\prime}, \boldsymbol{x}^{\prime}\right)
\end{array}\right) \in \mathbb{R}^{2 \times 2}, \\
\Sigma^{(h)}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right) &=c_{\sigma} \underset{(u, v) \sim \mathcal{N}\left(\mathbf{0}, \boldsymbol{\Lambda}^{(h)}\right)}{\mathbb{E}}[\sigma(u) \sigma(v)]
\end{align}
$$

We will also define a derivative covariance

$$
\dot{\Sigma}^{(h)}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right)=c_{\sigma} \underset{(u, v) \sim \mathcal{N}\left(\mathbf{0}, \boldsymbol{\Lambda}^{(h)}\right)}{\mathbb{E}}[\dot{\sigma}(u) \dot{\sigma}(v)]
$$
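
These expectations can be evaluated numerically. The sketch below (assuming $\sigma=\mathrm{ReLU}$, so $c_\sigma = 2$, with an arbitrary Monte Carlo sample size) iterates the recursion for a single pair $(\boldsymbol x, \boldsymbol x^{\prime})$.

```python
# Sketch: the covariance recursion Sigma^(h) and its derivative counterpart
# Sigma_dot^(h) for a pair (x, x'), with the expectations estimated by Monte Carlo.
import numpy as np

rng = np.random.default_rng(0)
c_sigma = 2.0                                  # for sigma = ReLU
sigma  = lambda z: np.maximum(z, 0.0)
dsigma = lambda z: (z > 0).astype(float)

def next_cov(Sxx, Sxy, Syy, n_mc=200_000):
    """One step of the recursion: from the entries of Lambda^(h), return
    (Sigma^(h)(x,x), Sigma^(h)(x,x'), Sigma^(h)(x',x'), Sigma_dot^(h)(x,x'))."""
    Lam = np.array([[Sxx, Sxy], [Sxy, Syy]])
    uv = rng.multivariate_normal(np.zeros(2), Lam, size=n_mc)
    u, v = uv[:, 0], uv[:, 1]
    return (c_sigma * np.mean(sigma(u) ** 2),
            c_sigma * np.mean(sigma(u) * sigma(v)),
            c_sigma * np.mean(sigma(v) ** 2),
            c_sigma * np.mean(dsigma(u) * dsigma(v)))

x  = np.array([1.0, 0.0])
xp = np.array([0.6, 0.8])
Sxx, Sxy, Syy = x @ x, x @ xp, xp @ xp         # Sigma^(0) entries
for h in range(1, 4):                          # layers h = 1, 2, 3
    Sxx, Sxy, Syy, Sdot = next_cov(Sxx, Sxy, Syy)
    print(f"h={h}: Sigma={Sxy:.3f}, Sigma_dot={Sdot:.3f}")
```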

We want to derive the neural tangent kernel, namely, $$\Theta^{(L)}(\boldsymbol x,\boldsymbol x^{\prime}) =\left\langle {\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol\theta}, {\partial f(\boldsymbol \theta(t),\boldsymbol x^{\prime})\over \partial \boldsymbol\theta} \right\rangle$$

To get ${\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol\theta}$, we need ${\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol W^{(h)}}$. From the last layer, we have
$$
\begin{align}
f(\boldsymbol \theta(t),\boldsymbol x)&=\boldsymbol W^{(L+1)}\boldsymbol g^{(L)}(\boldsymbol x)\\
{df(\boldsymbol \theta(t),\boldsymbol x)\over d\boldsymbol W^{(L+1)}}&=\left(
\boldsymbol g^{(L)}(\boldsymbol x)\right)^\intercal
\end{align}
$$

Then

$$
\begin{align}
f(\boldsymbol \theta(t),\boldsymbol x)&=\boldsymbol W^{(L+1)}\sqrt{c_\sigma\over d_{L}}\sigma\left(\boldsymbol f^{(L)}(\boldsymbol x)\right)=\boldsymbol W^{(L+1)}\sqrt{c_\sigma\over d_{L}}\sigma\left(\boldsymbol W^{(L)}\boldsymbol g^{(L-1)}(\boldsymbol x)\right)\\
{df(\boldsymbol \theta(t),\boldsymbol x)\over d\boldsymbol W^{(L)}}&=\sum_i {\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol g^{(L)}_i(\boldsymbol x)}{\partial \boldsymbol g^{(L)}_i(\boldsymbol x)\over \partial \boldsymbol f^{(L)}_i(\boldsymbol x)}{\partial \boldsymbol f^{(L)}_i(\boldsymbol x)\over \partial\boldsymbol W^{(L)}}\\
&=\sqrt{c_\sigma\over d_{L}}\text{diag}\left(\dot\sigma\left(\boldsymbol f^{(L)}_i(\boldsymbol x)\right)\right)\left(\boldsymbol W^{(L+1)}\right)^\intercal\left(\boldsymbol g^{(L-1)}(\boldsymbol x)\right)^\intercal\\
&=\boldsymbol b^{(L)}(\boldsymbol x)\left(\boldsymbol g^{(L-1)}(\boldsymbol x)\right)^\intercal
\end{align}
$$

Here $\boldsymbol b^{(L)}(\boldsymbol x)=\sqrt{c_\sigma\over d_{L}}\text{diag}\left(\dot\sigma\left(\boldsymbol f^{(L)}_i(\boldsymbol x)\right)\right)\left(\boldsymbol W^{(L+1)}\right)^\intercal$

$$
\begin{align}
{df(\boldsymbol \theta(t),\boldsymbol x)\over d\boldsymbol W^{(L-1)}}&=\sum_i {\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol g^{(L)}_i(\boldsymbol x)}{\partial \boldsymbol g^{(L)}_i(\boldsymbol x)\over \partial \boldsymbol f^{(L)}_i(\boldsymbol x)}{\partial \boldsymbol f^{(L)}_i(\boldsymbol x)\over \partial\boldsymbol W^{(L-1)}}\\
&=\sum_i \boldsymbol b^{(L)}_i(\boldsymbol x){\partial \boldsymbol f^{(L)}_i(\boldsymbol x)\over \partial\boldsymbol W^{(L-1)}}\\
&=\sum_i \boldsymbol b^{(L)}_i(\boldsymbol x)\sqrt{c_\sigma\over d_{L-1}}\text{diag}\left(\dot\sigma\left(\boldsymbol f^{(L-1)}_i(\boldsymbol x)\right)\right)\left(\boldsymbol W^{(L)}\right)^\intercal\left(\boldsymbol g^{(L-2)}(\boldsymbol x)\right)^\intercal\\
&=\sqrt{c_\sigma\over d_{L-1}}\text{diag}\left(\dot\sigma\left(\boldsymbol f^{(L-1)}_i(\boldsymbol x)\right)\right)\left(\boldsymbol W^{(L)}\right)^\intercal\boldsymbol b^{(L)}(\boldsymbol x)\left(\boldsymbol g^{(L-2)}(\boldsymbol x)\right)^\intercal\\
&=\boldsymbol b^{(L-1)}(\boldsymbol x)\left(\boldsymbol g^{(L-2)}(\boldsymbol x)\right)^\intercal
\end{align}
$$

Here $\boldsymbol b^{(L-1)}(\boldsymbol x)=\sqrt{c_\sigma\over d_{L-1}}\text{diag}\left(\dot\sigma\left(\boldsymbol f^{(L-1)}_i(\boldsymbol x)\right)\right)\left(\boldsymbol W^{(L)}\right)^\intercal\boldsymbol b^{(L)}(\boldsymbol x)$

By induction, we have

$$\boldsymbol b^{(h)}(\boldsymbol x)=\begin{cases}
1&h=L+1\\
\sqrt{c_\sigma\over d_{h}}\text{diag}\left(\dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x)\right)\right)\left(\boldsymbol W^{(h+1)}\right)^\intercal\boldsymbol b^{(h+1)}(\boldsymbol x)&h=L,L-1,\ldots,1
\end{cases}
$$

Then, we have

$$
{\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol W^{(h)}}=\boldsymbol b^{(h)}(\boldsymbol x)\left(\boldsymbol g^{(h-1)}(\boldsymbol x)\right)^\intercal
$$
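
This formula is easy to verify numerically. The sketch below (a two-hidden-layer ReLU network with made-up widths) computes $\boldsymbol b^{(h)}$ by the backward recursion, forms $\partial f/\partial \boldsymbol W^{(h)} = \boldsymbol b^{(h)}(\boldsymbol x)\left(\boldsymbol g^{(h-1)}(\boldsymbol x)\right)^\intercal$, and checks one entry against a finite difference.

```python
# Sketch: the backward vectors b^(h) and the per-layer gradients
# df/dW^(h) = b^(h) (g^(h-1))^T for a small ReLU network (L = 2).
import numpy as np

rng = np.random.default_rng(0)
c = 2.0                                        # c_sigma for ReLU
d0, d1, d2 = 4, 8, 8                           # input dim and hidden widths
W1 = rng.normal(size=(d1, d0))
W2 = rng.normal(size=(d2, d1))
W3 = rng.normal(size=(1, d2))                  # W^(L+1)
x = rng.normal(size=d0)

def forward(W1, W2, W3, x):
    f1 = W1 @ x;  g1 = np.sqrt(c / d1) * np.maximum(f1, 0.0)
    f2 = W2 @ g1; g2 = np.sqrt(c / d2) * np.maximum(f2, 0.0)
    return (W3 @ g2).item(), (f1, g1, f2, g2)

out, (f1, g1, f2, g2) = forward(W1, W2, W3, x)

# backward recursion: b^(3) = 1, b^(h) = sqrt(c/d_h) diag(sigma'(f^(h))) (W^(h+1))^T b^(h+1)
b3 = np.ones(1)
b2 = np.sqrt(c / d2) * (f2 > 0) * (W3.T @ b3)
b1 = np.sqrt(c / d1) * (f1 > 0) * (W2.T @ b2)
dW2 = np.outer(b2, g1)                         # df/dW^(2) = b^(2) (g^(1))^T
dW1 = np.outer(b1, x)                          # df/dW^(1) = b^(1) (g^(0))^T

# finite-difference check on one entry of W^(2)
eps = 1e-6
W2p = W2.copy(); W2p[3, 5] += eps
num = (forward(W1, W2p, W3, x)[0] - out) / eps
print(dW2[3, 5], num)                          # the two values should agree closely
```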

Then,

$$
\begin{align}
\left\langle {\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol W^{(h)}}, {\partial f(\boldsymbol \theta(t),\boldsymbol x^{\prime})\over \partial \boldsymbol W^{(h)}} \right\rangle &=\left\langle \boldsymbol b^{(h)}(\boldsymbol x)\left(\boldsymbol g^{(h-1)}(\boldsymbol x)\right)^\intercal, \boldsymbol b^{(h)}(\boldsymbol x^{\prime})\left(\boldsymbol g^{(h-1)}(\boldsymbol x^{\prime})\right)^\intercal \right\rangle\\
&=\mathrm{Tr}\left( \boldsymbol g^{(h-1)}(\boldsymbol x)\left(\boldsymbol b^{(h)}(\boldsymbol x)\right)^\intercal \boldsymbol b^{(h)}(\boldsymbol x^{\prime})\left(\boldsymbol g^{(h-1)}(\boldsymbol x^{\prime})\right)^\intercal \right)\\
&=\mathrm{Tr}\left( \left(\boldsymbol b^{(h)}(\boldsymbol x)\right)^\intercal \boldsymbol b^{(h)}(\boldsymbol x^{\prime})\left(\boldsymbol g^{(h-1)}(\boldsymbol x^{\prime})\right)^\intercal\boldsymbol g^{(h-1)}(\boldsymbol x) \right)\\
&=\left\langle \boldsymbol b^{(h)}(\boldsymbol x), \boldsymbol b^{(h)}(\boldsymbol x^{\prime}) \right\rangle\cdot\left\langle \boldsymbol g^{(h-1)}(\boldsymbol x^{\prime}),\boldsymbol g^{(h-1)}(\boldsymbol x) \right\rangle
\end{align}
$$
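
The last step uses the rank-one identity $\left\langle \boldsymbol a\boldsymbol b^\intercal, \boldsymbol c\boldsymbol d^\intercal\right\rangle=\langle \boldsymbol a,\boldsymbol c\rangle\langle \boldsymbol b,\boldsymbol d\rangle$; a quick numeric check (my own, with arbitrary vectors):

```python
# Numeric check of the identity <a b^T, c d^T>_F = Tr(b a^T c d^T) = <a, c> <b, d>.
import numpy as np

rng = np.random.default_rng(0)
a, c = rng.normal(size=5), rng.normal(size=5)
b, d = rng.normal(size=7), rng.normal(size=7)
lhs = np.sum(np.outer(a, b) * np.outer(c, d))   # Frobenius inner product
rhs = (a @ c) * (b @ d)
print(np.isclose(lhs, rhs))                     # True
```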

In the infinite-width limit, $\left\langle \boldsymbol g^{(h-1)}(\boldsymbol x),\boldsymbol g^{(h-1)}(\boldsymbol x^{\prime}) \right\rangle\to\Sigma^{(h-1)}(\boldsymbol x,\boldsymbol x^{\prime})$. For the other factor,

$$
\begin{align}
&\left\langle \boldsymbol b^{(h)}(\boldsymbol x), \boldsymbol b^{(h)}(\boldsymbol x^{\prime}) \right\rangle\\
=&\left\langle \sqrt{c_\sigma\over d_{h}}\text{diag}\left(\dot\sigma\left(\boldsymbol f^{(h)}(\boldsymbol x)\right)\right)\left(\boldsymbol W^{(h+1)}\right)^\intercal\boldsymbol b^{(h+1)}(\boldsymbol x),\sqrt{c_\sigma\over d_{h}}\text{diag}\left(\dot\sigma\left(\boldsymbol f^{(h)}(\boldsymbol x^{\prime})\right)\right)\left(\boldsymbol W^{(h+1)}\right)^\intercal\boldsymbol b^{(h+1)}(\boldsymbol x^{\prime}) \right\rangle\\
=&{c_\sigma\over d_{h}}\sum_i \left(\dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x)\right)\sum_j W^{(h+1)}_{ji} b^{(h+1)}_j(\boldsymbol x)\right)\left(\dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x^{\prime})\right)\sum_k W^{(h+1)}_{ki} b^{(h+1)}_k(\boldsymbol x^{\prime})\right)\\
=&{c_\sigma\over d_{h}}\sum_i \dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x)\right)\dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x^{\prime})\right) \sum_j \sum_k W^{(h+1)}_{ji} W^{(h+1)}_{ki} b^{(h+1)}_j(\boldsymbol x) b^{(h+1)}_k(\boldsymbol x^{\prime})\\
\approx& {c_\sigma\over d_{h}}\sum_i \dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x)\right)\dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x^{\prime})\right) \sum_j b^{(h+1)}_j(\boldsymbol x) b^{(h+1)}_j(\boldsymbol x^{\prime})\\
&\left(\boldsymbol W^{(h+1)}\text{ and }\boldsymbol b^{(h+1)}(\boldsymbol x)\text{ can be treated as independent, and }\mathbb E\left[W^{(h+1)}_{ji}W^{(h+1)}_{ki}\right]=\delta_{jk}\right)\\
=& {c_\sigma\over d_{h}}\left\langle\dot\sigma\left(\boldsymbol f^{(h)}(\boldsymbol x)\right),\dot\sigma\left(\boldsymbol f^{(h)}(\boldsymbol x^{\prime})\right) \right\rangle \cdot \left\langle\boldsymbol b^{(h+1)}(\boldsymbol x),\boldsymbol b^{(h+1)}(\boldsymbol x^{\prime})\right\rangle\\
\approx& \dot\Sigma^{(h)}(\boldsymbol x,\boldsymbol x^{\prime}) \left\langle\boldsymbol b^{(h+1)}(\boldsymbol x),\boldsymbol b^{(h+1)}(\boldsymbol x^{\prime})\right\rangle
\end{align}
$$

where the last step holds in the infinite-width limit: ${c_\sigma\over d_{h}}\sum_i \dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x)\right)\dot\sigma\left(\boldsymbol f^{(h)}_i(\boldsymbol x^{\prime})\right)\to c_{\sigma} \underset{(u, v) \sim \mathcal{N}\left(\mathbf{0}, \boldsymbol{\Lambda}^{(h)}\right)}{\mathbb{E}}[\dot{\sigma}(u) \dot{\sigma}(v)]=\dot{\Sigma}^{(h)}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right)$.

By induction, we have

$$\left\langle \boldsymbol b^{(h)}(\boldsymbol x), \boldsymbol b^{(h)}(\boldsymbol x^{\prime}) \right\rangle=\prod_{h^{\prime}=h}^{L+1}\dot\Sigma^{(h^{\prime})}(\boldsymbol x,\boldsymbol x^{\prime})$$ where we adopt the convention $\dot\Sigma^{(L+1)}(\boldsymbol x,\boldsymbol x^{\prime})\equiv 1$, corresponding to the base case $\boldsymbol b^{(L+1)}(\boldsymbol x)=1$.

Finally,

$$
\begin{align}
\Theta^{(L)}(\boldsymbol x,\boldsymbol x^{\prime})&=\left\langle {\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol\theta}, {\partial f(\boldsymbol \theta(t),\boldsymbol x^{\prime})\over \partial \boldsymbol\theta} \right\rangle\\
&=\sum_{h=1}^{L+1}\left\langle {\partial f(\boldsymbol \theta(t),\boldsymbol x)\over \partial \boldsymbol W^{(h)}}, {\partial f(\boldsymbol \theta(t),\boldsymbol x^{\prime})\over \partial \boldsymbol W^{(h)}} \right\rangle\\
&=\sum_{h=1}^{L+1}\left(\Sigma^{(h-1)}(\boldsymbol x,\boldsymbol x^{\prime})\prod_{h^{\prime}=h}^{L+1}\dot\Sigma^{(h^{\prime})}(\boldsymbol x,\boldsymbol x^{\prime})\right)
\end{align}
$$ $\blacksquare$
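
For $\sigma=\mathrm{ReLU}$ the Gaussian expectations defining $\Sigma^{(h)}$ and $\dot\Sigma^{(h)}$ have well-known closed forms (the arc-cosine kernel), so $\Theta^{(L)}$ can be evaluated exactly. The sketch below is my own illustration under that assumption and uses the convention $\dot\Sigma^{(L+1)}\equiv 1$ from above.

```python
# Sketch: the analytic NTK Theta^(L)(x, x') via the recursion derived above,
# using the closed-form ReLU expectations instead of Monte Carlo.
import numpy as np

def relu_expectations(a, b, c):
    """Given Lambda = [[a, c], [c, b]], return c_sigma*E[sigma(u)sigma(v)] and
    c_sigma*E[sigma'(u)sigma'(v)] for sigma = ReLU, c_sigma = 2."""
    rho = np.clip(c / np.sqrt(a * b), -1.0, 1.0)
    theta = np.arccos(rho)
    Sig = np.sqrt(a * b) / np.pi * (np.sin(theta) + (np.pi - theta) * np.cos(theta))
    Sig_dot = (np.pi - theta) / np.pi
    return Sig, Sig_dot

def ntk(x, xp, L):
    """Theta^(L)(x,x') = sum_{h=1}^{L+1} Sigma^(h-1) prod_{h'=h}^{L+1} Sigma_dot^(h'),
    with the convention Sigma_dot^(L+1) := 1."""
    Sxx, Sxy, Syy = x @ x, x @ xp, xp @ xp          # Sigma^(0) entries
    Sigmas, Sigma_dots = [Sxy], []
    for _ in range(L):                              # h = 1, ..., L
        Sxy, Sdot = relu_expectations(Sxx, Syy, Sxy)
        # for ReLU with c_sigma = 2 the diagonal entries are preserved:
        # Sigma^(h)(x,x) = Sigma^(h-1)(x,x), so Sxx and Syy stay fixed
        Sigmas.append(Sxy)
        Sigma_dots.append(Sdot)
    Sigma_dots.append(1.0)                          # Sigma_dot^(L+1) := 1
    return sum(Sigmas[h - 1] * np.prod(Sigma_dots[h - 1:]) for h in range(1, L + 2))

x, xp = np.array([1.0, 0.0]), np.array([0.6, 0.8])
print(ntk(x, xp, L=2))
```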

Convergence to the NTK

The following result from [1] makes the convergence precise. Fix $\epsilon>0$ and $\delta \in(0,1)$. Suppose $\sigma(z)=\max (0, z)$ and $\min_{h \in[L]} d_{h} \geq \Omega\left(\frac{L^{14}}{\epsilon^{4}} \log (L / \delta)\right)$. Then for any inputs $\boldsymbol{x}, \boldsymbol{x}^{\prime} \in \mathbb{R}^{d_{0}}$ such that $\|\boldsymbol{x}\| \leq 1$ and $\left\|\boldsymbol{x}^{\prime}\right\| \leq 1$, with probability at least $1-\delta$ we have:
$$
\left|\left\langle\frac{\partial f(\boldsymbol{\theta}, \boldsymbol{x})}{\partial \boldsymbol{\theta}}, \frac{\partial f\left(\boldsymbol{\theta}, \boldsymbol{x}^{\prime}\right)}{\partial \boldsymbol{\theta}}\right\rangle-\Theta^{(L)}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right)\right| \leq \epsilon
$$

TBD

NTK Dynamics

Since,

$${d\boldsymbol u\over dt}=-\boldsymbol H^*(\boldsymbol u(t) -\boldsymbol y)$$

Suppose we decompose $\boldsymbol H^*$ into its eigenvectors and eigenvalues,
$$\boldsymbol H^* \boldsymbol v_i=\lambda_i \boldsymbol v_i$$

Then, we could consider the dynamics of $\boldsymbol u(t) $ on each eigenvector separately,
$$
\begin{aligned}
{d\boldsymbol v_i^\top \boldsymbol u\over dt}&=-\boldsymbol v_i^\top \boldsymbol H^*(\boldsymbol u(t) -\boldsymbol y)\\
&=-\lambda_i\left( \boldsymbol v_i^\top(\boldsymbol u(t) -\boldsymbol y)\right)
\end{aligned}
$$

Solving this one-dimensional ODE yields
$$\boldsymbol v_i^\top(\boldsymbol u(t) -\boldsymbol y)=e^{-\lambda_it}\left( \boldsymbol v_i^\top(\boldsymbol u(0) -\boldsymbol y)\right)$$

Thus, every $\lambda_i$ must be strictly positive for the training error to converge to zero; in other words, $\boldsymbol H^*$ has to be positive definite.
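
A small numerical illustration (with a toy positive-definite $\boldsymbol H^*$ and random targets of my own choosing): the eigendecomposition solution above matches a direct Euler discretization of the gradient flow.

```python
# Sketch: solving du/dt = -H*(u - y) along the eigenvectors of H*,
# compared against a small-step Euler integration of the same ODE.
import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.normal(size=(n, n))
H = A @ A.T + 0.1 * np.eye(n)               # a positive-definite stand-in for H*
y = rng.normal(size=n)
u0 = rng.normal(size=n)
lam, V = np.linalg.eigh(H)                  # H* v_i = lambda_i v_i

def u_exact(t):
    # v_i^T (u(t) - y) = exp(-lambda_i t) * v_i^T (u(0) - y)
    return y + V @ (np.exp(-lam * t) * (V.T @ (u0 - y)))

# Euler discretization of the gradient flow with a small step
u, dt = u0.copy(), 1e-3
for _ in range(int(2.0 / dt)):              # integrate up to t = 2
    u -= dt * H @ (u - y)

print(np.max(np.abs(u - u_exact(2.0))))     # small (Euler discretization error)
```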

References

[1] Arora, Sanjeev, Simon S. Du, Wei Hu, Zhiyuan Li, Russ R. Salakhutdinov, and Ruosong Wang. “On Exact Computation with an Infinitely Wide Neural Net.” Advances in Neural Information Processing Systems 32 (2019): 8141-8150.

[2] STATS 403 Lecture Slides. Duke Kunshan University.