{ "cells": [ { "cell_type": "markdown", "id": "7fbf07a8", "metadata": {}, "source": [ "(linalg_tutorial)=\n", "# Intro to the linear algebra module\n", "Most functions in the linear algebra module are thin wrappers, only a few lines long, with an API nearly identical to their numpy counterparts. In general, all you need to do is pass the input DataArray and indicate which dimensions\n", "correspond to the matrices. There are only a couple of exceptions, and each one has its own section." ] }, { "cell_type": "code", "execution_count": 1, "id": "f634cd80", "metadata": {}, "outputs": [], "source": [ "import xarray_einstats\n", "from xarray_einstats.tutorial import generate_matrices_dataarray" ] }, { "cell_type": "markdown", "id": "60f466a1", "metadata": {}, "source": [ "We start by generating synthetic data to work with:" ] }, { "cell_type": "code", "execution_count": 2, "id": "2d6947e2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n",
       "0.7075 1.025 0.5685 0.8951 0.2065 3.384 ... 1.239 0.4527 0.5749 0.4766 0.859\n",
       "Dimensions without coordinates: batch, experiment, dim, dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n", "0.7075 1.025 0.5685 0.8951 0.2065 3.384 ... 1.239 0.4527 0.5749 0.4766 0.859\n", "Dimensions without coordinates: batch, experiment, dim, dim2" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da = generate_matrices_dataarray(7)\n", "da" ] }, { "cell_type": "markdown", "id": "6d3974b1", "metadata": {}, "source": [ "The data represents a collection of matrices. `dim` and `dim2` are the matrix dimensions; the whole array is 4D and holds 30 matrices in total, from 10 batches and 3 experiments.\n", "\n", "(linalg_tutorial/general)=\n", "## General linalg functions\n", "You can get the trace of all 30 matrices in a single line. You only need the input DataArray and the dimensions corresponding to the matrices:" ] }, { "cell_type": "code", "execution_count": 3, "id": "83b90669", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3)> Size: 240B\n",
       "4.854 4.74 4.457 2.637 2.79 3.163 1.998 ... 2.804 4.58 2.888 4.936 5.983 4.07\n",
       "Dimensions without coordinates: batch, experiment
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3)> Size: 240B\n", "4.854 4.74 4.457 2.637 2.79 3.163 1.998 ... 2.804 4.58 2.888 4.936 5.983 4.07\n", "Dimensions without coordinates: batch, experiment" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da.linalg.trace(dims=[\"dim\", \"dim2\"])" ] }, { "cell_type": "markdown", "id": "94683b64", "metadata": {}, "source": [ "The main feature of the wrappers is that they know the expected shape of the output, so you don't need to take care of it. See how the inverse, which doesn't reduce the matrix dimensions, can be called with the exact same arguments:" ] }, { "cell_type": "code", "execution_count": 4, "id": "105a9335", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n",
       "11.26 -2.363 -10.84 -0.2744 10.99 -2.017 ... -3.444 0.7703 0.316 0.01949 -1.162\n",
       "Dimensions without coordinates: batch, experiment, dim, dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n", "11.26 -2.363 -10.84 -0.2744 10.99 -2.017 ... -3.444 0.7703 0.316 0.01949 -1.162\n", "Dimensions without coordinates: batch, experiment, dim, dim2" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da.linalg.inv(dims=[\"dim\", \"dim2\"])" ] }, { "cell_type": "markdown", "id": "4c2d4c8a", "metadata": {}, "source": [ "Even a QR decomposition, which returns multiple matrices (potentially with different shapes), needs only these two arguments to work. Note that batched QR decomposition requires numpy>=1.22." ] }, { "cell_type": "code", "execution_count": 5, "id": "f12f2ccb", "metadata": {}, "outputs": [], "source": [ "q, r = da.linalg.qr(dims=[\"dim\", \"dim2\"])" ] }, { "cell_type": "code", "execution_count": 6, "id": "78c4b53f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n",
       "-0.5452 0.01652 -0.5624 -0.6214 -0.1592 ... -0.3322 -0.4013 0.2607 0.8128\n",
       "Dimensions without coordinates: batch, experiment, dim, dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n", "-0.5452 0.01652 -0.5624 -0.6214 -0.1592 ... -0.3322 -0.4013 0.2607 0.8128\n", "Dimensions without coordinates: batch, experiment, dim, dim2" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "q" ] }, { "cell_type": "code", "execution_count": 7, "id": "1eb487a1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n",
       "-1.298 -1.975 -1.858 -1.228 0.0 -3.137 ... -0.4307 1.052 0.0 0.0 0.0 -0.6995\n",
       "Dimensions without coordinates: batch, experiment, dim, dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n", "-1.298 -1.975 -1.858 -1.228 0.0 -3.137 ... -0.4307 1.052 0.0 0.0 0.0 -0.6995\n", "Dimensions without coordinates: batch, experiment, dim, dim2" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "r" ] }, { "cell_type": "markdown", "id": "cbf3de67", "metadata": {}, "source": [ ":::{tip}\n", "Do you always follow the same naming convention for your matrix dimensions and feel that repeating them in every call is\n", "unnecessary? Take a look at {func}`xarray_einstats.linalg.get_default_dims` to see how to modify the default dims used by the linalg wrappers.\n", ":::" ] }, { "cell_type": "markdown", "id": "0a136362", "metadata": {}, "source": [ "(linalg_tutorial/matmul)=\n", "## matmul: 1st exception\n", "The general representation of a matrix multiplication is:\n", "\n", "$$\n", "\\mathcal{M}_1^{N\\times K} * \\mathcal{M}_2^{K\\times M} = \\mathcal{M}^{N\\times M}\n", "$$ (eq:matmul)\n", "\n", "There are conceptually 3 dimensions involved in the operation because the 2nd dimension of $\\mathcal{M}_1$\n", "needs to be the same as the 1st dimension of $\\mathcal{M}_2$. Moreover, when working with square matrices, $N=M=K$\n", "and there is only 1 dimension.\n", "\n", "In xarray, however, a DataArray can't have repeated dimension names, so, as we have already seen, conceptually equivalent dimensions can have different names, e.g. `dim` and `dim2`.\n", "\n", "Taking all of this into account, `matmul`'s `dims` argument supports indicating the dimensions in 3 different ways. 
The following table summarizes the inputs `dims` accepts and how they are interpreted:\n", "\n", "| `dims` | dim_a1 | dim_a2 | dim_b1 | dim_b2 |\n", "|--------|--------|--------|--------|--------|\n", "| `[dim1, dim2]` | dim1 | dim2 | dim1 | dim2 |\n", "| `[dim1, dim2, dim3]` | dim1 | dim2 | dim2 | dim3 |\n", "| `[[dim_a1, dim_a2], [dim_b1, dim_b2]]` | dim_a1 | dim_a2 | dim_b1 | dim_b2 |\n", "\n", "where `dim_a1, dim_a2` are the matrix dimensions of the first matrix, and `dim_b1, dim_b2` are the matrix dimensions\n", "of the second matrix. Like in {eq}`eq:matmul`, **the dimensions present in the output are `dim_a1, dim_b2`.**\n", "\n", "### List of two elements\n", "\n", "This first example uses square matrices, so both matrix dimensions are shared by the two inputs. You only need a list with two strings to indicate how to perform the multiplication:" ] }, { "cell_type": "code", "execution_count": 8, "id": "8e2325d2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n",
       "1.845 5.326 2.407 3.89 3.378 14.68 5.449 ... 5.586 6.55 1.279 1.373 1.791 2.658\n",
       "Dimensions without coordinates: batch, experiment, dim, dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB\n", "1.845 5.326 2.407 3.89 3.378 14.68 5.449 ... 5.586 6.55 1.279 1.373 1.791 2.658\n", "Dimensions without coordinates: batch, experiment, dim, dim2" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from xarray_einstats import linalg\n", "\n", "linalg.matmul(da, da, dims=[\"dim\", \"dim2\"])" ] }, { "cell_type": "markdown", "id": "b4941ffb", "metadata": {}, "source": [ "### List of three elements" ] }, { "cell_type": "markdown", "id": "eabf865f", "metadata": {}, "source": [ "However, the input matrices for matrix multiplication might not be square or might not have the exact same dimension names. As we have seen, what is necessary is for the 2nd dimension of the 1st matrix to match the 1st dimension of the 2nd matrix. This 3-element list of dimensions is arguably the most common way to specify matrix multiplications.\n", "\n", "You could interpret the DataArray as a collection of matrices with `batch, experiment` as the matrix dimensions, or with `experiment, dim2` as the matrix dimensions. Those two collections of matrices are valid inputs for a matrix multiplication, because `experiment` is both the 2nd matrix dimension of the first collection and the 1st matrix dimension of the second one.\n", "\n", "As there is still one dimension that needs to match, `matmul` can also take a list of 3 dimensions:" ] }, { "cell_type": "code", "execution_count": 9, "id": "4e3de04f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (dim: 4, dim2_bis: 4, batch_bis: 10, batch: 10, dim2: 4)> Size: 51kB\n",
       "10.79 3.926 1.503 3.986 0.1886 0.1844 ... 1.289 4.187 5.251 3.372 2.81 13.1\n",
       "Dimensions without coordinates: dim, dim2_bis, batch_bis, batch, dim2
" ], "text/plain": [ "<xarray.DataArray (dim: 4, dim2_bis: 4, batch_bis: 10, batch: 10, dim2: 4)> Size: 51kB\n", "10.79 3.926 1.503 3.986 0.1886 0.1844 ... 1.289 4.187 5.251 3.372 2.81 13.1\n", "Dimensions without coordinates: dim, dim2_bis, batch_bis, batch, dim2" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linalg.matmul(da, da, dims=[\"batch\", \"experiment\", \"dim2\"], out_append=\"_bis\")" ] }, { "cell_type": "markdown", "id": "fdb15ef3", "metadata": {}, "source": [ "Here, `batch` and `dim2` were matrix dimensions in one of the matrices and batch dimensions in the other. While this\n", "might not be very common, `xarray-einstats` checks for dimensions that would end up being duplicated in the output and, if necessary, renames them using `out_append` to avoid collisions.\n", "\n", "A similar thing happens when `dim1` and `dim3` from the table above have the same name:" ] }, { "cell_type": "code", "execution_count": 10, "id": "76d55d22", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (dim: 4, dim2: 4, batch: 10, batch2: 10)> Size: 13kB\n",
       "10.79 0.1886 5.402 1.471 1.243 5.348 2.639 ... 3.462 3.618 11.21 9.47 4.187 13.1\n",
       "Dimensions without coordinates: dim, dim2, batch, batch2
" ], "text/plain": [ "<xarray.DataArray (dim: 4, dim2: 4, batch: 10, batch2: 10)> Size: 13kB\n", "10.79 0.1886 5.402 1.471 1.243 5.348 2.639 ... 3.462 3.618 11.21 9.47 4.187 13.1\n", "Dimensions without coordinates: dim, dim2, batch, batch2" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linalg.matmul(da, da, dims=[\"batch\", \"experiment\", \"batch\"])" ] }, { "cell_type": "markdown", "id": "796bec63", "metadata": {}, "source": [ "### List of 2-element lists\n", "The 3rd option is the most verbose and explicit one, but it is still necessary to multiply some matrices without having to rename their dimensions manually first.\n", "\n", "To see how it works, you'll need a `db` object with the same shape but different dimension names:" ] }, { "cell_type": "code", "execution_count": 11, "id": "b1628bc6", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, different_dim: 4, different_dim2: 4)> Size: 4kB\n",
       "0.7075 1.025 0.5685 0.8951 0.2065 3.384 ... 1.239 0.4527 0.5749 0.4766 0.859\n",
       "Dimensions without coordinates: batch, experiment, different_dim, different_dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, different_dim: 4, different_dim2: 4)> Size: 4kB\n", "0.7075 1.025 0.5685 0.8951 0.2065 3.384 ... 1.239 0.4527 0.5749 0.4766 0.859\n", "Dimensions without coordinates: batch, experiment, different_dim, different_dim2" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "db = da.rename(dim=\"different_dim\", dim2=\"different_dim2\")\n", "db" ] }, { "cell_type": "markdown", "id": "7ee9b631", "metadata": {}, "source": [ "Now `da` and `db` are compatible and you might want to multiply them; after all, it's the same operation we did in the first `matmul` example (you can compare the results if you run the notebook). But given the name mismatch, neither the first nor the second option can be used:" ] }, { "cell_type": "code", "execution_count": 12, "id": "6a6880cd", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, different_dim2: 4)> Size: 4kB\n",
       "1.845 5.326 2.407 3.89 3.378 14.68 5.449 ... 5.586 6.55 1.279 1.373 1.791 2.658\n",
       "Dimensions without coordinates: batch, experiment, dim, different_dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, dim: 4, different_dim2: 4)> Size: 4kB\n", "1.845 5.326 2.407 3.89 3.378 14.68 5.449 ... 5.586 6.55 1.279 1.373 1.791 2.658\n", "Dimensions without coordinates: batch, experiment, dim, different_dim2" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linalg.matmul(da, db, dims=[[\"dim\", \"dim2\"], [\"different_dim\", \"different_dim2\"]])" ] }, { "cell_type": "markdown", "id": "fc2a2c4c", "metadata": {}, "source": [ "Whenever the dimension being multiplied/reduced doesn't have the same name in both matrices, you'll need to use this 2+2 dims specification. Like in the 3-element list case, `matmul` avoids name clashes:" ] }, { "cell_type": "code", "execution_count": 13, "id": "55e58fe9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (dim: 4, dim2: 4, experiment: 3, experiment2: 3)> Size: 1kB\n",
       "9.727 6.68 3.595 6.68 18.66 6.065 3.595 ... 10.81 36.08 8.181 3.233 8.181 14.77\n",
       "Dimensions without coordinates: dim, dim2, experiment, experiment2
" ], "text/plain": [ "<xarray.DataArray (dim: 4, dim2: 4, experiment: 3, experiment2: 3)> Size: 1kB\n", "9.727 6.68 3.595 6.68 18.66 6.065 3.595 ... 10.81 36.08 8.181 3.233 8.181 14.77\n", "Dimensions without coordinates: dim, dim2, experiment, experiment2" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dc = da.rename(batch=\"batch_bis\")\n", "linalg.matmul(da, dc, dims=[[\"experiment\", \"batch\"], [\"batch_bis\", \"experiment\"]])" ] }, { "cell_type": "markdown", "id": "349126f9", "metadata": {}, "source": [ "(linalg_tutorial/einsum)=\n", "## einsum: 2nd and most notable exception\n", "`einsum` is such a flexible function that it can even be intimidating. It covers anything from `sum` operations to {func}`xarray.dot` reductions, and obviously some operations similar to `einops`, which, after all, is inspired by einsum.\n", "\n", "The goal of this section is to be neither an extensive nor an in-depth guide to einsum, but a small ladder: from simple operations that you can do without einsum up to operations that are only possible with einsum. This will give you a good look at the unique version of `einsum` in `xarray_einstats`, which works with named dimensions, and you'll see how most einsum operations translate to our syntax.\n", "\n", "If you want to master einsum, however, we direct you to the {func}`numpy.einsum` documentation and the [einops](https://einops.rocks/) package. To make the tutorial easier to follow without understanding einsum beforehand, we provide the non-einsum equivalents (which often take multiple operations) inside toggle-able note boxes. But keep in mind that the goal of this section is not to teach how to use `einsum` but to show how to use\n", "`xarray_einstats` to perform einsum operations with named dimensions."
] }, { "cell_type": "code", "execution_count": 14, "id": "4ebd0d05", "metadata": {}, "outputs": [], "source": [ "from xarray_einstats import einsum\n", "import xarray as xr" ] }, { "cell_type": "markdown", "id": "987df7be", "metadata": {}, "source": [ "Start by reducing the `experiment` dimension. Ellipses, broadcasting and transpositions are all handled by xarray and xarray-einstats, so you only need to care about the dimensions you want to operate on. Use an empty list `[]` as the output to indicate that you want to reduce the dimension (or leave the right-hand side of `->` empty in the string syntax):" ] }, { "cell_type": "code", "execution_count": 15, "id": "3dd53554", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, dim: 4, dim2: 4)> Size: 1kB\n",
       "4.487 3.158 0.9252 2.683 0.5319 3.799 ... 3.387 1.796 2.601 2.455 1.538 5.402\n",
       "Dimensions without coordinates: batch, dim, dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, dim: 4, dim2: 4)> Size: 1kB\n", "4.487 3.158 0.9252 2.683 0.5319 3.799 ... 3.387 1.796 2.601 2.455 1.538 5.402\n", "Dimensions without coordinates: batch, dim, dim2" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[\"experiment\"], []], da)\n", "einsum(\"experiment->\", da)" ] }, { "cell_type": "markdown", "id": "664dde4c", "metadata": {}, "source": [ "The same can be done with multiple dimensions:" ] }, { "cell_type": "code", "execution_count": 16, "id": "74565c10", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (dim: 4, dim2: 4)> Size: 128B\n",
       "22.27 32.55 29.06 40.96 23.96 33.48 ... 25.27 29.59 34.97 20.57 34.89 30.26\n",
       "Dimensions without coordinates: dim, dim2
" ], "text/plain": [ "<xarray.DataArray (dim: 4, dim2: 4)> Size: 128B\n", "22.27 32.55 29.06 40.96 23.96 33.48 ... 25.27 29.59 34.97 20.57 34.89 30.26\n", "Dimensions without coordinates: dim, dim2" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[\"batch\", \"experiment\"], []], da)\n", "einsum(\"batch experiment->\", da)" ] }, { "cell_type": "markdown", "id": "668189ed", "metadata": {}, "source": [ ":::{note}\n", ":class: dropdown\n", "\n", "The two previous examples are respectively equivalent to:\n", "\n", "```python\n", "da.sum(\"experiment\")\n", "da.sum((\"batch\", \"experiment\"))\n", "```\n", ":::\n", "\n", "`einsum` also takes multiple inputs. In those cases, if there are repeated dimensions in the expressions\n", "corresponding to different inputs and we want to reduce all of them, the output expression can be omitted, just like\n", "you'd do with `numpy.einsum`." ] }, { "cell_type": "code", "execution_count": 17, "id": "28ca15d9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, dim: 4, dim2: 4)> Size: 1kB\n",
       "10.79 3.543 0.4447 2.399 0.111 11.58 10.95 ... 5.104 1.799 2.513 3.052 0.79 13.1\n",
       "Dimensions without coordinates: batch, dim, dim2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, dim: 4, dim2: 4)> Size: 1kB\n", "10.79 3.543 0.4447 2.399 0.111 11.58 10.95 ... 5.104 1.799 2.513 3.052 0.79 13.1\n", "Dimensions without coordinates: batch, dim, dim2" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[\"experiment\"], [\"experiment\"]], da, da)\n", "einsum(\"experiment,experiment\", da, da)" ] }, { "cell_type": "markdown", "id": "e1a20ca2", "metadata": {}, "source": [ ":::{note}\n", ":class: dropdown\n", "\n", "This call combines a product and a summation, and has two equivalents: one using `xarray.dot` (also quite einsum-like), and another using simple mathematical operations:\n", "\n", "```python\n", "xr.dot(da, da, dims=\"experiment\")\n", "(da * da).sum(\"experiment\")\n", "```\n", ":::\n", "\n", "When there are no repeated indexes between inputs, the results of _implicit_ and _explicit_ mode differ, again just like in `numpy.einsum`. After all, `xarray_einstats.einsum` is an interface to it that uses\n", "dimension names and needs no ellipsis.\n", "\n", "**Implicit mode:**" ] }, { "cell_type": "code", "execution_count": 18, "id": "b9645aac", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (dim: 4, dim2: 4, batch: 10, experiment: 3)> Size: 4kB\n",
       "33.15 44.26 22.52 1.318 1.76 0.8951 ... 19.52 36.93 18.42 42.62 80.64 40.23\n",
       "Dimensions without coordinates: dim, dim2, batch, experiment
" ], "text/plain": [ "<xarray.DataArray (dim: 4, dim2: 4, batch: 10, experiment: 3)> Size: 4kB\n", "33.15 44.26 22.52 1.318 1.76 0.8951 ... 19.52 36.93 18.42 42.62 80.64 40.23\n", "Dimensions without coordinates: dim, dim2, batch, experiment" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[\"experiment\"], [\"batch\"]], da, da)\n", "einsum(\"experiment,batch\", da, da)" ] }, { "cell_type": "markdown", "id": "5304c907", "metadata": {}, "source": [ "This call no longer has a single-operation equivalent. Here we are performing multiple summations\n", "and multiplications, and also reordering the dimensions. The first time they are encountered, dimensions are mapped to a single letter (the input accepted by einsum) in reverse alphabetical order; after that, the saved mapping is used. In implicit mode, `numpy.einsum` sorts the output letters alphabetically. Therefore, also following the `xarray.apply_ufunc` convention of placing core dimensions last, the default order of the output dimensions is the following:\n", "1. All the omitted dimensions _in the order they appear in the inputs_\n", "2. All dimensions present in the expressions _in the **inverse** order of their first appearance in the expressions_\n", "\n", "Thus, the output has `dim` and `dim2` first, as they are not present in the expressions, followed by `batch` and finally `experiment`, the exact inverse of the order in which they appear in the expression.\n", "\n", ":::{note}\n", ":class: dropdown\n", "\n", "Even though this computation no longer has a single function/method equivalent, it does have a multiple\n", "operation equivalent:\n", "\n", "```python\n", "(da.sum(\"experiment\") * da.sum(\"batch\")).transpose(..., \"batch\", \"experiment\")\n", "```\n", ":::\n", "\n", "**Explicit mode:**" ] }, { "cell_type": "code", "execution_count": 19, "id": "c30ec8b3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (dim: 4, dim2: 4)> Size: 128B\n",
       "496.0 1.06e+03 844.2 1.678e+03 573.9 ... 875.4 1.223e+03 423.1 1.218e+03 915.8\n",
       "Dimensions without coordinates: dim, dim2
" ], "text/plain": [ "<xarray.DataArray (dim: 4, dim2: 4)> Size: 128B\n", "496.0 1.06e+03 844.2 1.678e+03 573.9 ... 875.4 1.223e+03 423.1 1.218e+03 915.8\n", "Dimensions without coordinates: dim, dim2" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[\"experiment\"], [\"batch\"], []], da, da)\n", "einsum(\"experiment,batch->\", da, da)" ] }, { "cell_type": "markdown", "id": "6d0dacc6", "metadata": {}, "source": [ ":::{note}\n", ":class: dropdown\n", "\n", "This again has no single-operation equivalent, but it does have a multiple-operation one:\n", "\n", "```python\n", "(da.sum(\"experiment\") * da.sum(\"batch\")).sum((\"batch\", \"experiment\"))\n", "```\n", ":::" ] }, { "cell_type": "markdown", "id": "987c58fa", "metadata": {}, "source": [ "**Relation to {func}`xarray.dot`**\n", "\n", "`xarray.dot` is also a wrapper around `numpy.einsum`, but it takes a single list of dimensions to operate on. This means that neither of the two computations above can be reproduced with it. See for yourself:" ] }, { "cell_type": "code", "execution_count": 20, "id": "d4743bef", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (dim: 4, dim2: 4)> Size: 128B\n",
       "32.03 68.57 42.6 101.0 40.06 76.88 59.44 ... 33.78 83.88 72.41 32.43 76.33 60.63\n",
       "Dimensions without coordinates: dim, dim2
" ], "text/plain": [ "<xarray.DataArray (dim: 4, dim2: 4)> Size: 128B\n", "32.03 68.57 42.6 101.0 40.06 76.88 59.44 ... 33.78 83.88 72.41 32.43 76.33 60.63\n", "Dimensions without coordinates: dim, dim2" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xr.dot(da, da, dims=[\"experiment\", \"batch\"])" ] }, { "cell_type": "markdown", "id": "5f1d436e", "metadata": {}, "source": [ "`xarray.dot` operates on both dimensions for both inputs. Therefore, to reproduce its results, we need to tell `einsum` to operate on both dimensions for both inputs. This is similar to what we did a couple of examples back, but now with two dimensions." ] }, { "cell_type": "code", "execution_count": 21, "id": "7d648895", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (dim: 4, dim2: 4)> Size: 128B\n",
       "32.03 68.57 42.6 101.0 40.06 76.88 59.44 ... 33.78 83.88 72.41 32.43 76.33 60.63\n",
       "Dimensions without coordinates: dim, dim2
" ], "text/plain": [ "<xarray.DataArray (dim: 4, dim2: 4)> Size: 128B\n", "32.03 68.57 42.6 101.0 40.06 76.88 59.44 ... 33.78 83.88 72.41 32.43 76.33 60.63\n", "Dimensions without coordinates: dim, dim2" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[\"batch\", \"experiment\"], [\"batch\", \"experiment\"], []], da, da)\n", "einsum(\"batch experiment,batch experiment->\", da, da)" ] }, { "cell_type": "markdown", "id": "bc1dc7d8", "metadata": {}, "source": [ ":::{note}\n", ":class: dropdown\n", "\n", "Similarly to the earlier example using `dot`, the equivalent here is a product followed by a sum over the two provided dimensions.\n", "\n", "```python\n", "(da * da).sum((\"batch\", \"experiment\"))\n", "```\n", ":::\n", "\n", "**`keep_dims` argument**\n", "\n", "`einsum` also has an argument to indicate dimensions that are present in multiple inputs but should be \"kept\". That is, instead of treating the dimension as shared by all inputs, its occurrences in the different inputs should be preserved, as if they were actually different dimensions. `einsum` will then rename the repeated dimension names\n", "using the `out_append` argument.\n", "\n", "Going back to our DataArray as a collection of matrices, we might want to do a matrix multiplication. That would mean reducing the `dim2` dimensions and _keeping_ both `dim` dimensions to get the resulting collection of matrices. You can see how the default `einsum` behaviour only keeps one occurrence of `dim`:" ] }, { "cell_type": "code", "execution_count": 22, "id": "960c189b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, dim: 4)> Size: 960B\n",
       "2.676 19.38 0.8116 5.562 11.33 2.104 ... 6.259 12.24 6.737 0.5945 7.355 1.5\n",
       "Dimensions without coordinates: batch, experiment, dim
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, dim: 4)> Size: 960B\n", "2.676 19.38 0.8116 5.562 11.33 2.104 ... 6.259 12.24 6.737 0.5945 7.355 1.5\n", "Dimensions without coordinates: batch, experiment, dim" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[\"dim2\"], [\"dim2\"]], da, da)\n", "einsum(\"dim2,dim2\", da, da)" ] }, { "cell_type": "markdown", "id": "a20d51e2", "metadata": {}, "source": [ "We need to use the `keep_dims` argument to keep the `dim` dimension of the first DataArray as `dim` and the `dim` dimension of the 2nd DataArray as a _new_ dimension, whose name is built with `out_append` (`dim_auto2` here):" ] }, { "cell_type": "code", "execution_count": 23, "id": "a01ec920", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim_auto2: 4)> Size: 4kB\n",
       "2.676 6.135 1.302 3.007 6.135 19.38 2.018 ... 7.355 2.884 2.942 0.8866 2.884 1.5\n",
       "Dimensions without coordinates: batch, experiment, dim, dim_auto2
" ], "text/plain": [ "<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim_auto2: 4)> Size: 4kB\n", "2.676 6.135 1.302 3.007 6.135 19.38 2.018 ... 7.355 2.884 2.942 0.8866 2.884 1.5\n", "Dimensions without coordinates: batch, experiment, dim, dim_auto2" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[\"dim2\"], [\"dim2\"]], da, da, keep_dims={\"dim\"}, out_append=\"_auto{i}\")\n", "einsum(\"dim2,dim2\", da, da, keep_dims={\"dim\"}, out_append=\"_auto{i}\")" ] }, { "cell_type": "markdown", "id": "147b598b", "metadata": {}, "source": [ "Note that both arguments have dimensions `dim, dim2` and we have reduced `dim2` in both of them. Therefore, we haven't computed the matrix multiplication between the two arguments, but the matrix multiplication between the first argument and the _transpose_ of the second. For seamless matrix multiplication, use {func}`xarray_einstats.matmul`.\n", "\n", ":::{important}\n", "`xarray_einstats.einsum` does not support combining dimensions with different names.\n", ":::\n", "\n", "The `keep_dims` argument can also be used to perform outer products. In a pure outer product, we don't want to reduce any dimension, so we give empty lists as the input expressions and pass the dimension we want to perform the outer product on as `keep_dims`:" ] }, { "cell_type": "code", "execution_count": 24, "id": "65773e8d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (experiment: 3, dim: 4, dim2: 4, batch: 10, batch2: 10)> Size: 38kB\n",
       "0.5006 0.09001 0.1315 0.3874 0.5949 0.6645 ... 2.931 0.2908 0.5802 0.4342 0.7379\n",
       "Dimensions without coordinates: experiment, dim, dim2, batch, batch2
" ], "text/plain": [ "<xarray.DataArray (experiment: 3, dim: 4, dim2: 4, batch: 10, batch2: 10)> Size: 38kB\n", "0.5006 0.09001 0.1315 0.3874 0.5949 0.6645 ... 2.931 0.2908 0.5802 0.4342 0.7379\n", "Dimensions without coordinates: experiment, dim, dim2, batch, batch2" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "einsum([[], []], da, da, keep_dims={\"batch\"})" ] }, { "cell_type": "code", "execution_count": 25, "id": "5e921907", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Last updated: Thu May 22 2025\n", "\n", "Python implementation: CPython\n", "Python version : 3.12.7\n", "IPython version : 8.29.0\n", "\n", "numpy: 2.2.6\n", "\n", "xarray : 2025.4.0\n", "xarray_einstats: 0.9.0\n", "\n", "Watermark: 2.5.0\n", "\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -n -u -v -iv -w -p numpy" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }