Skip to content

Commit d6b7c3c

Browse files
yadav-sachinpre-commit-ci[bot]dfm
authored
Adding more details to deep kernel learning example (#70)
* add details to deep kernel example * remove blackcellmagic cell * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add news fragment * add more description * show matern-3/2 on transformed features * add more description * [pre-commit.ci] pre-commit autoupdate (#71) updates: - [github.com/hadialqattan/pycln: v1.2.4 → v1.2.5](hadialqattan/pycln@v1.2.4...v1.2.5) - [github.com/pre-commit/mirrors-mypy: v0.931 → v0.940](pre-commit/mirrors-mypy@v0.931...v0.940) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * try simplify intuition * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add smoothing assumption * some edits Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Dan F-M <foreman.mackey@gmail.com>
1 parent 1861232 commit d6b7c3c

File tree

2 files changed

+172
-33
lines changed

2 files changed

+172
-33
lines changed

docs/tutorials/transforms.ipynb

Lines changed: 169 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,15 @@
8787
"metadata": {},
8888
"source": [
8989
"Then we will fit this model using a model similar to the one described in {ref}`modeling-flax`, except our kernel will include a custom {class}`tinygp.kernels.Transform` that will pass the input coordinates through a (small) neural network before passing them into a {class}`tinygp.kernels.Matern32` kernel.\n",
90-
"Otherwise, the model and optimization procedure are similar to the ones used in {ref}`modeling-flax`."
90+
"Otherwise, the model and optimization procedure are similar to the ones used in {ref}`modeling-flax`.\n",
91+
"\n",
92+
"We compare the performance of the Deep Matern-3/2 kernel (a {class}`tinygp.kernels.Matern32` kernel, with custom neural network transform) to the performance of the same kernel without the transform. The untransformed model doesn't have the capacity to capture our simulated step function, but our transformed model does. In our transformed model, the hyperparameters of our kernel now include the weights of our neural network transform, and we learn those simultaneously with the length scale and amplitude of the `Matern32` kernel."
9193
]
9294
},
9395
{
9496
"cell_type": "code",
9597
"execution_count": null,
96-
"id": "e0065dea-379a-4e0d-8cf2-f460c8126a5f",
98+
"id": "94938ebe",
9799
"metadata": {},
98100
"outputs": [],
99101
"source": [
@@ -102,11 +104,49 @@
102104
"import jax.numpy as jnp\n",
103105
"import flax.linen as nn\n",
104106
"from flax.linen.initializers import zeros\n",
105-
"from tinygp import kernels, transforms, GaussianProcess\n",
106-
"\n",
107+
"from tinygp import kernels, transforms, GaussianProcess"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": null,
113+
"id": "21e63e64",
114+
"metadata": {
115+
"tags": [
116+
"hide-cell"
117+
]
118+
},
119+
"outputs": [],
120+
"source": [
121+
"class Matern32Loss(nn.Module):\n",
122+
" @nn.compact\n",
123+
" def __call__(self, x, y, t):\n",
124+
" # Set up a typical Matern-3/2 kernel\n",
125+
" log_sigma = self.param(\"log_sigma\", zeros, ())\n",
126+
" log_rho = self.param(\"log_rho\", zeros, ())\n",
127+
" log_jitter = self.param(\"log_jitter\", zeros, ())\n",
128+
" base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(\n",
129+
" jnp.exp(log_rho)\n",
130+
" )\n",
107131
"\n",
108-
"# Define a small neural network used to non-linearly transform the input data in our model\n",
132+
" # Evaluate and return the GP negative log likelihood as usual\n",
133+
" gp = GaussianProcess(\n",
134+
" base_kernel, x[:, None], diag=noise**2 + jnp.exp(2 * log_jitter)\n",
135+
" )\n",
136+
" log_prob, gp_cond = gp.condition(y, t[:, None])\n",
137+
" return -log_prob, (gp_cond.loc, gp_cond.variance)"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"id": "e0065dea-379a-4e0d-8cf2-f460c8126a5f",
144+
"metadata": {},
145+
"outputs": [],
146+
"source": [
109147
"class Transformer(nn.Module):\n",
148+
" \"\"\"A small neural network used to non-linearly transform the input data\"\"\"\n",
149+
"\n",
110150
" @nn.compact\n",
111151
" def __call__(self, x):\n",
112152
" x = nn.Dense(features=15)(x)\n",
@@ -117,7 +157,7 @@
117157
" return x\n",
118158
"\n",
119159
"\n",
120-
"class GPLoss(nn.Module):\n",
160+
"class DeepLoss(nn.Module):\n",
121161
" @nn.compact\n",
122162
" def __call__(self, x, y, t):\n",
123163
" # Set up a typical Matern-3/2 kernel\n",
@@ -128,56 +168,152 @@
128168
" jnp.exp(log_rho)\n",
129169
" )\n",
130170
"\n",
131-
" # Define a custom transform to pass the input coordinates through our `Transformer`\n",
132-
" # network from above\n",
171+
" # Define a custom transform to pass the input coordinates through our\n",
172+
" # `Transformer` network from above\n",
133173
" transform = Transformer()\n",
134174
" kernel = transforms.Transform(transform, base_kernel)\n",
135175
"\n",
136-
" # Evaluate and return the GP negative log likelihood as usual\n",
176+
" # Evaluate and return the GP negative log likelihood as usual with the\n",
177+
" # transformed features\n",
137178
" gp = GaussianProcess(\n",
138179
" kernel, x[:, None], diag=noise**2 + jnp.exp(2 * log_jitter)\n",
139180
" )\n",
140181
" log_prob, gp_cond = gp.condition(y, t[:, None])\n",
141-
" return -log_prob, (gp_cond.loc, gp_cond.variance)\n",
182+
"\n",
183+
" # We return the loss, the conditional mean and variance, and the\n",
184+
" # transformed input parameters\n",
185+
" return (\n",
186+
" -log_prob,\n",
187+
" (gp_cond.loc, gp_cond.variance),\n",
188+
" (transform(x[:, None]), transform(t[:, None])),\n",
189+
" )\n",
142190
"\n",
143191
"\n",
144192
"# Define and train the model\n",
145-
"def loss(params):\n",
146-
" return model.apply(params, x, y, t)[0]\n",
193+
"def loss_func(model):\n",
194+
" def loss(params):\n",
195+
" return model.apply(params, x, y, t)[0]\n",
147196
"\n",
197+
" return loss\n",
148198
"\n",
149-
"model = GPLoss()\n",
150-
"params = model.init(jax.random.PRNGKey(1234), x, y, t)\n",
151-
"tx = optax.sgd(learning_rate=1e-4)\n",
152-
"opt_state = tx.init(params)\n",
153-
"loss_grad_fn = jax.jit(jax.value_and_grad(loss))\n",
154-
"for i in range(1000):\n",
155-
" loss_val, grads = loss_grad_fn(params)\n",
156-
" updates, opt_state = tx.update(grads, opt_state)\n",
157-
" params = optax.apply_updates(params, updates)\n",
158199
"\n",
200+
"models_list, params_list = [], []\n",
201+
"loss_vals = {}\n",
159202
"# Plot the results and compare to the true model\n",
160-
"plt.figure()\n",
161-
"mu, var = model.apply(params, x, y, t)[1]\n",
162-
"plt.plot(t, 2 * (t > 0) - 1, \"k\", lw=1, label=\"truth\")\n",
163-
"plt.plot(x, y, \".k\", label=\"data\")\n",
164-
"plt.plot(t, mu)\n",
165-
"plt.fill_between(\n",
166-
" t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label=\"model\"\n",
167-
")\n",
168-
"plt.xlim(-1.5, 1.5)\n",
169-
"plt.ylim(-1.3, 1.3)\n",
170-
"plt.xlabel(\"x\")\n",
171-
"plt.ylabel(\"y\")\n",
203+
"fig, ax = plt.subplots(ncols=2, sharey=True, figsize=(9, 3))\n",
204+
"for it, (model_name, model) in enumerate(\n",
205+
" zip(\n",
206+
" [\"Deep\", \"Matern32\"],\n",
207+
" [DeepLoss(), Matern32Loss()],\n",
208+
" )\n",
209+
"):\n",
210+
" loss_vals[it] = []\n",
211+
" params = model.init(jax.random.PRNGKey(1234), x, y, t)\n",
212+
" tx = optax.sgd(learning_rate=1e-4)\n",
213+
" opt_state = tx.init(params)\n",
214+
"\n",
215+
" loss = loss_func(model)\n",
216+
" loss_grad_fn = jax.jit(jax.value_and_grad(loss))\n",
217+
" for i in range(1000):\n",
218+
" loss_val, grads = loss_grad_fn(params)\n",
219+
" updates, opt_state = tx.update(grads, opt_state)\n",
220+
" params = optax.apply_updates(params, updates)\n",
221+
" loss_vals[it].append(loss_val)\n",
222+
"\n",
223+
" mu, var = model.apply(params, x, y, t)[1]\n",
224+
" ax[it].plot(t, 2 * (t > 0) - 1, \"k\", lw=1, label=\"truth\")\n",
225+
" ax[it].plot(x, y, \".k\", label=\"data\")\n",
226+
" ax[it].plot(t, mu)\n",
227+
" ax[it].fill_between(\n",
228+
" t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label=\"model\"\n",
229+
" )\n",
230+
" ax[it].set_xlim(-1.5, 1.5)\n",
231+
" ax[it].set_ylim(-1.3, 1.3)\n",
232+
" ax[it].set_xlabel(\"x\")\n",
233+
" ax[it].set_ylabel(\"y\")\n",
234+
" ax[it].set_title(model_name)\n",
235+
" _ = ax[it].legend()\n",
236+
"\n",
237+
" models_list.append(model)\n",
238+
" params_list.append(params)"
239+
]
240+
},
241+
{
242+
"cell_type": "markdown",
243+
"id": "bb4d5f08",
244+
"metadata": {},
245+
"source": [
246+
"The untransformed `Matern32` model suffers from over-smoothing at the discontinuity, and poor extrapolation performance.\n",
247+
"The `Deep` model extrapolates well and captures the discontinuity reliably.\n",
248+
"\n",
249+
"We can compare the training loss (negative log likelihood) traces for these two models:"
250+
]
251+
},
252+
{
253+
"cell_type": "code",
254+
"execution_count": null,
255+
"id": "feff3a28",
256+
"metadata": {},
257+
"outputs": [],
258+
"source": [
259+
"fig = plt.plot()\n",
260+
"plt.plot(loss_vals[0], label=\"Deep\")\n",
261+
"plt.plot(loss_vals[1], label=\"Matern32\")\n",
262+
"plt.ylabel(\"Loss\")\n",
263+
"plt.xlabel(\"Training Iterations\")\n",
172264
"_ = plt.legend()"
173265
]
174266
},
267+
{
268+
"cell_type": "markdown",
269+
"id": "5692e918",
270+
"metadata": {},
271+
"source": [
272+
"To inspect what the transformed model is doing under the hood, we can plot the functional form of the transformation, as well as the transformed values of our input coordinates: "
273+
]
274+
},
175275
{
176276
"cell_type": "code",
177277
"execution_count": null,
178278
"id": "a281b035-513a-4215-87fd-1a83b52ebd79",
179279
"metadata": {},
180280
"outputs": [],
281+
"source": [
282+
"x_transform, t_transform = models_list[0].apply(params_list[0], x, y, t)[2]\n",
283+
"\n",
284+
"fig = plt.figure()\n",
285+
"plt.plot(t, t_transform, \"k\")\n",
286+
"plt.xlim(-1.5, 1.5)\n",
287+
"plt.ylim(-1.3, 1.3)\n",
288+
"plt.xlabel(\"input data; x\")\n",
289+
"plt.ylabel(\"transformed data; x'\")\n",
290+
"\n",
291+
"fig, ax = plt.subplots(ncols=2, sharey=True, figsize=(9, 3))\n",
292+
"for it, (fig_title, feature_input, x_label) in enumerate(\n",
293+
" zip([\"Input Data\", \"Transformed Data\"], [x, x_transform], [\"x\", \"x'\"])\n",
294+
"):\n",
295+
" ax[it].plot(feature_input, y, \".k\")\n",
296+
" ax[it].set_xlim(-1.5, 1.5)\n",
297+
" ax[it].set_ylim(-1.3, 1.3)\n",
298+
" ax[it].set_title(fig_title)\n",
299+
" ax[it].set_xlabel(x_label)\n",
300+
" ax[it].set_ylabel(\"y\")"
301+
]
302+
},
303+
{
304+
"cell_type": "markdown",
305+
"id": "673435d3",
306+
"metadata": {},
307+
"source": [
308+
"The neural network transforms the input feature into a step function like data (as shown in the figures above) before feeding to the base kernel, making it better suited than the baseline model for this data."
309+
]
310+
},
311+
{
312+
"cell_type": "code",
313+
"execution_count": null,
314+
"id": "1d805ca0",
315+
"metadata": {},
316+
"outputs": [],
181317
"source": []
182318
}
183319
],

news/70.doc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Add more details to Deep Kernel learning tutorial,
2+
showing comparison with Matern-3/2 kernel
3+
and the transformed features.

0 commit comments

Comments
 (0)