Tensor contraction over specified indices and outer product.
tf.einsum(
    equation, *inputs, **kwargs
)
Einsum allows defining Tensors by specifying their element-wise computation.
This computation is defined by equation, a shorthand form based on Einstein
summation. As an example, consider multiplying two matrices A and B to form a
matrix C. The elements of C are given by:
\[ C_{i,k} = \sum_j A_{i,j} B_{j,k} \]
or
C[i,k] = sum_j A[i,j] * B[j,k]
The corresponding einsum equation is:
ij,jk->ik
In general, to convert the element-wise equation into the equation string,
use the following procedure (intermediate strings for the matrix-multiplication
example are shown in parentheses):
- remove variable names, brackets, and commas (ik = sum_j ij * jk)
- replace "*" with "," (ik = sum_j ij , jk)
- drop summation signs (ik = ij, jk)
- move the output to the right and replace "=" with "->" (ij,jk->ik)
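As an additional worked illustration (not part of the matrix-multiplication example above), applying the same procedure to a row-wise sum, s[i] = sum_j A[i,j], gives the equation 'ij->i':
a = tf.random.normal(shape=[2, 3])
s = tf.einsum('ij->i', a)  # s[i] = sum_j a[i,j]
print(s.shape)
(2,)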
Many common operations can be expressed in this way. For example:
Matrix multiplication
m0 = tf.random.normal(shape=[2, 3])
m1 = tf.random.normal(shape=[3, 5])
e = tf.einsum('ij,jk->ik', m0, m1)
# output[i,k] = sum_j m0[i,j] * m1[j, k]
print(e.shape)
(2, 5)
Repeated indices are summed if the output indices are not specified.
e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
print(e.shape)
(2, 5)
Dot product
u = tf.random.normal(shape=[5])
v = tf.random.normal(shape=[5])
e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
print(e.shape)
()
Outer product
u = tf.random.normal(shape=[3])
v = tf.random.normal(shape=[5])
e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
print(e.shape)
(3, 5)
Transpose
m = tf.ones([2, 3])
e = tf.einsum('ij->ji', m)  # output[j,i] = m[i,j]
print(e.shape)
(3, 2)
Diag
m = tf.reshape(tf.range(9), [3, 3])
diag = tf.einsum('ii->i', m)
print(diag.shape)
(3,)
Trace
# Repeated indices are summed.
trace = tf.einsum('ii', m)  # output = trace(m) = sum_i m[i, i]
assert trace == sum(diag)
print(trace.shape)
()
Batch matrix multiplication
s = tf.random.normal(shape=[7, 5, 3])
t = tf.random.normal(shape=[7, 3, 2])
e = tf.einsum('bij,bjk->bik', s, t)
# output[b,i,k] = sum_j s[b,i,j] * t[b, j, k]
print(e.shape)
(7, 5, 2)
This method does not support broadcasting on named axes. All axes with
matching labels must have the same length. If you have length-1 axes,
use tf.squeeze or tf.reshape to eliminate them, as in the sketch below.
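For instance, a minimal sketch (with illustrative shapes not taken from the examples above) of dropping a length-1 axis before contracting:
u = tf.random.normal(shape=[3, 1])
v = tf.random.normal(shape=[3, 5])
u = tf.squeeze(u, axis=1)  # shape [3]; the label 'i' now has one consistent length
e = tf.einsum('i,ik->k', u, v)  # output[k] = sum_i u[i] * v[i,k]
print(e.shape)
(5,)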
To write code that is agnostic to the number of indices in the input, use an ellipsis. The ellipsis is a placeholder for "whatever other indices fit here".
For example, to perform a NumPy-style broadcasting batch matrix multiplication, where the matrix multiply acts on the last two axes of the input, use:
s = tf.random.normal(shape=[11, 7, 5, 3])
t = tf.random.normal(shape=[11, 7, 3, 2])
e = tf.einsum('...ij,...jk->...ik', s, t)
print(e.shape)
(11, 7, 5, 2)
Einsum will broadcast over axes covered by the ellipsis.
s = tf.random.normal(shape=[11, 1, 5, 3])
t = tf.random.normal(shape=[1, 7, 3, 2])
e = tf.einsum('...ij,...jk->...ik', s, t)
print(e.shape)
(11, 7, 5, 2)
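The ellipsis is not limited to matrix multiplication. As an illustrative sketch (assuming, as with numpy.einsum, that the ellipsis may also match zero axes), the same notation gives a transpose of the last two axes that works for any rank:
m2 = tf.random.normal(shape=[2, 3])
m4 = tf.random.normal(shape=[7, 11, 2, 3])
print(tf.einsum('...ij->...ji', m2).shape)  # ellipsis matches nothing
(3, 2)
print(tf.einsum('...ij->...ji', m4).shape)  # ellipsis matches the two batch axes
(7, 11, 3, 2)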
Args | |
|---|---|
| equation | a str describing the contraction, in the same format as numpy.einsum. |
| *inputs | the inputs to contract (each one a Tensor), whose shapes should be consistent with equation. |
| **kwargs | |
Returns | |
|---|---|
| The contracted Tensor, with shape determined by equation. |
Raises | |
|---|---|
| ValueError | If the format of equation is incorrect, or the number of inputs or their shapes are inconsistent with equation. |