Motivation
Based on prior discussions, e.g. here and here, a `linalg.contract` op would be a worthwhile addition to the arsenal of Linalg named ops. The main benefit is being able to directly represent contractions - especially ones not covered by current ops, e.g. matmul-like ops which employ arbitrary transposes/permuted dims and/or have high-dimensional operands - without needing to go to `linalg.generic`s. Our main motivation is that contractions represented as `linalg.contract` (rather than `linalg.generic`) allow optimizations, e.g. applying/undoing arbitrary transposes, to be expressed more easily, and admit straightforward lowerings, in particular to `vector.contract`.

A contraction op in Linalg is also a (standalone) piece of the Linalg operation tree puzzle.
Limitations of status quo
- We have a wild growth of matmul-like ops that we wish to halt (and eventually prune): `linalg.matmul(_transpose_a|_transpose_b)?`, `linalg.batch_matmul(_transpose_a|_transpose_b)?`, `linalg.batch_reduce_matmul`, `linalg.matvec`, `linalg.vecmat`, `linalg.batch_matvec`, `linalg.batch_vecmat`, `linalg.mmt4d`, `linalg.batch_mmt4d`, `linalg.dot`.
  - Though there's progress, see e.g. here, these ops do not allow all the possible dim permutations people care about (e.g. selecting contracting dim(s) in a `batch_(reduce_)matmul`).
- Ever more specific ops do not give a path to a general contraction, one supporting arbitrary broadcasts and transposes on operands with an arbitrary number of dims.
- The availability of an op of appropriate generality would mean the introduction of further named ops similar to the ones above would need much stronger justification.
- `linalg.generic` is too general for this class of einsum-like ops:
  - The guarantee that an op's `indexing_maps` are restricted to projected permutations means we know that a number of transforms can be applied without reservation.
    - For example, transposes and broadcasts can be applied/folded into such a contraction op itself, without going to `linalg.generic`.
  - The more gradual lowering to a contraction op (vs. `linalg.generic`) allows, e.g., permutation decisions to be easily amended after the packing transform.
- In practice, matching `linalg.generic`s to lower to `vector.contract` can be a pain - a suitable contraction op would have a straightforward and efficient "canonical" lowering to `vector.contract` that always applies.
- Named ops, at the right abstraction level, allow for encoding invariants like op-is-a-contraction into the IR, which can, e.g.,
  - be a convenient anchor for hero op matching;
  - be used by tools (like the Transform dialect) to prove transforms, and compositions thereof, are valid/well-defined for every well-typed input.
Proposal: introduce `linalg.contract`

In essence: `vector.contract` but at the Linalg level.
Syntax
```
contract-op ::= `linalg.contract`
                  `indexing_maps` `=` `[` affine-map+ `]`
                  (`iterator_types` `=` `[` ( `parallel` | `reduction` )+ `]`)?
                  (`kind` `=` reduction-op)?
                  `ins(` $A `,` $B `:` tensor-or-memref-type `,` tensor-or-memref-type `)`
                  `outs(` $C `:` tensor-or-memref-type `)`

reduction-op      ::= `#linalg.kind` `<` reduction-op-kind `>`
reduction-op-kind ::= `add` | `mul` | `minui` | ... | `minimumf`
```
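As a sketch of how the proposed syntax could look in practice (the operand names, shapes, and the trailing result type are made-up assumptions, with the result type following the convention of other linalg named ops on tensors), a plain matmul might be written as:

```mlir
// Hypothetical instance of the proposed syntax: plain matmul.
// k appears in both inputs but not in the output map, so it is contracted.
%D = linalg.contract
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (k, n)>,
                     affine_map<(m, n, k) -> (m, n)>]
    ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
    outs(%C : tensor<16x32xf32>) -> tensor<16x32xf32>
```

Transposed, broadcasted, or batched variants would differ only in the `indexing_maps`, which is the point of the op's generality.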
The verifier checks that

- the `indexing_maps` attribute consists of `affine_map`s which
  1. are projected permutations;
  2. encode a valid contraction - reducing at least one dimension - w.r.t. the `ins` operands' types;
  3. map the `outs` operand's dims onto a subset of the dims of the `ins` operands;
- if provided, `iterator_types` matches the iterator types implied by the projected permutation maps.
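To make check 2 concrete, a purely elementwise "product" reduces no dimension and would be rejected (hypothetical IR following the proposed syntax; names and shapes are made up):

```mlir
// Hypothetically INVALID: every dim (m, n) appears in the output's map,
// so no dimension is contracted/reduced - violating verifier check 2.
%0 = linalg.contract
    indexing_maps = [affine_map<(m, n) -> (m, n)>,
                     affine_map<(m, n) -> (m, n)>,
                     affine_map<(m, n) -> (m, n)>]
    ins(%A, %B : tensor<4x4xf32>, tensor<4x4xf32>)
    outs(%C : tensor<4x4xf32>) -> tensor<4x4xf32>
```

Design choice 1 below discusses why such elementwise products are deliberately excluded.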
As with `vector.contract` (see docs):

- The optional `kind` attribute controls which operator is used for reducing/combining; it defaults to standard addition.
- A dim that occurs only in A or B, but not in the output, is a "free" dimension, one over which to reduce.
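To illustrate a "free" dimension with a hypothetical sketch (names, shapes, and the spelled-out optional attribute are assumptions): here b occurs in A but in neither B nor the output, so it is reduced over together with the contracted dim k:

```mlir
// Hypothetical: b occurs only in A (not in B or the output), making it a
// "free" dim; it is reduced over together with the contracted dim k.
// The optional iterator_types (ordered as (b, m, k)) spells out what the
// indexing_maps already imply.
%0 = linalg.contract
    indexing_maps = [affine_map<(b, m, k) -> (b, m, k)>,
                     affine_map<(b, m, k) -> (k)>,
                     affine_map<(b, m, k) -> (m)>]
    iterator_types = [reduction, parallel, reduction]
    ins(%A, %v : tensor<4x16x8xf32>, tensor<8xf32>)
    outs(%c : tensor<16xf32>) -> tensor<16xf32>
```

This is a batch-reduce matvec of sorts, which the current named ops listed above do not directly cover.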
Semantics
Einsum semantics per

C[J] = ⊕_{dims} A[I^A] · B[I^B],  with dims = (I^A ∪ I^B) \ J,

where I^A, I^B, and J are multi-indices, i.e. sequences/ordered sets of dimension identifiers (meant to range over index ranges) corresponding to the co-domains of the respective `affine_map`s, ⊕ is the selected `kind` of reduction operation, and ⊕_{dims} means reduce over all valid indices for the dimensions in the set dims (NB: per the verifier, dims cannot be empty).

Example: for matmul we have I^A = ⟨m, k⟩, I^B = ⟨k, n⟩, J = ⟨m, n⟩, and ⊕ is normal addition/summation.

Like all recent linalg named ops, and per `vector.contract`'s docs: "Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output."
Design choices (+ pros & cons of alternatives)
1. Require at least one contraction/reduction dimension

Rationale: this is what `vector.contract` does; it restricts the generality of the op.

Alternative: allow elementwise products, like an outer product.

Pros:

- Gain the ability to represent all binary einsums.
- Some strategies for rewriting trees of contractions/einsums convert contractions into elementwise "contractions" and back, hence it might be desirable to have one op to represent all intermediate states - if one op is desired, a separate einsum op would probably be a more appropriate solution.

Cons:

- Would be yet another way to represent elementwise products (vs. `linalg.mul` and `linalg.elementwise_binary` and `linalg.generic`s), which certainly complicates matching.
- Lose the property that each `linalg.contract` can be lowered to a `vector.contract`.
- Would mean `linalg.contract` cannot implement `ContractionOpInterface`, as it requires that a contraction op "has at least one reduction dimension", as the name suggests.
2. `affine_map`s to encode projected dim permutations

Rationale: this is what `vector.contract` and recent linalg ops use.

Alternative: specify reduction/contraction dims and permutation of dims as separate attributes (e.g. two arrays).

Pros:

- On the face of it, more separation of concerns, though changing the permutation array could necessitate changing the reduction/contraction dims array.

Cons:

- After long discussions, transposes on `linalg.matmul` were merged with an `indexing_maps` attribute instead of an attribute encoding projected permutations some other way.
- Per the `ContractionOpInterface`: "Has only projected permutation indexing maps", which, to be fair, could still be derived from the array attributes.
- For lowering to `vector.contract` we would need to infer the corresponding `affine_map`s anyway.
3. `iterator_types` attribute is optional

Rationale: a middle ground between needing to do inference in most cases (though not all) and having the ability to opt in to verification.

Alternative: require `iterator_types` to be provided.

Pros:

- As currently implemented, `vector.contract` requires the attribute, so for lowering to `vector.contract` you need it anyway.
- The attribute is there as a cache, and the (linear-scan) inference only needs to run in case the verifier runs.

Cons:

- Unlike for `linalg.generic`, `iterator_types` can always be inferred.
- Consensus was against inclusion of `iterator_types` when transposes on `linalg.matmul` (which uses `indexing_maps` to permute dims) got merged recently.

Alternative: no IR representation; only available through inference.

Pros:

- Less verbose IR.
- Can still be cached internally, e.g. across calls to `LinalgStructuredInterface::getIteratorTypesArray`, potentially already by the verifier.

Cons:

- Opting in to validation of supposed `iterator_types` by the verifier is not possible, which, to be fair, is the case for a number of linalg ops.
4. `linalg.contract` is a binary op

Rationale: keep to the convention of `vector.contract` and most linalg named ops (making for an easier time matching up semantics); could always be generalized later.

Alternative: allow a single input and/or more than two input operands.

Pros:

- Multi-operand contractions are valid operations and could be part of a valid lowering strategy (e.g. they can be represented by `linalg.generic`).
- Can represent more versions of einsum, closer to their original form.

Cons:

- Binary contractions suffice to implement multi-operand contractions [e.g., per an under-review paper].
- Lose the existence of a "canonical" lowering to `vector.contract`, which is a binary op.
- The single-operand version is already served by `linalg.reduce`; as above, two ways of writing the same thing (at the same abstraction level) complicates matching.
- `ContractionOpInterface` mandates a binary op, though "In the future, we may wish to allow more input arguments".
Actions
Primary / 1st PR:

- Implement the proposed abstraction, with it implementing the `ContractionOpInterface`.
- Implement generalization to `linalg.generic` and lowering to `vector.contract`.
- Change `inferContractionDims` to additionally return "free" dimensions, i.e. reduction dims that occur in the LHS or RHS but not both.
  - To maintain current API expectations, add `allowFreeDims=false` as an argument to `inferContractionDims`.
Secondary / follow-up PRs:

- Implement folding in of transposes before and after the `linalg.contract`.
- Rewrite the `packMatmulGreedily` transform to lower matmuls to `linalg.contract` instead of `linalg.generic`.
  - Enables easier cleanup, e.g. fiddling with transposes after this transform has run.
- Implement a raising/specialization transform from `linalg.generic` to `linalg.contract`.
- In line with [RFC][MLIR] Linalg operation tree, implement generalization/coercion transforms to `linalg.contract` for `linalg.dot`, `linalg.matmul(_transpose_a|_transpose_b)?`, `linalg.batch_matmul(_transpose_a|_transpose_b)?`, `linalg.batch_reduce_matmul`, `linalg.matvec`, `linalg.vecmat`, `linalg.batch_matvec`, `linalg.batch_vecmat`, `linalg.mmt4d`, and `linalg.batch_mmt4d`.
Alternatives
- Generalize the current collection of matmul-like ops to support higher dimensions and more permutations.
  - This would just give us `linalg.contract`, though in sheep's clothing, probably spread out over a number of ops. Unlikely to yield the same generality as a proper contraction op.
- Status quo: no linalg op that sits between the current matmul variants and `linalg.generic`s; `linalg.generic` remains the only representation for non-named-op contractions.
  - The matching story would need to be improved, e.g. by adopting "match `linalg.generic` and perform an `isContractionOpInterface` func call" as the preferred approach.
    - Each such scheme entails running checks with non-trivial cost on each and every `linalg.generic`, e.g. incurring scans over `indexing_maps` and region matching costs.
    - Hard to justify in the face of efficient named-op matching infrastructure being available.
    - Arguably, in general we do need such matching for contractions that happen to have been prematurely lowered to `linalg.generic` - though these ops could also be raised to `linalg.contract`, in which case you still incur the matching complexity.
- Could introduce a `projected_perm_map` attribute to be used in `indexing_maps` on `linalg.generic` (and other ops) to easily identify projected permutations on dims. This would reduce the cost of matching the attributes.
  - An advantage would be that all multi-argument einsums expressed as a single `linalg.generic` are easier to recognize through their `indexing_maps`.
  - We would have overlapping representations, with the `affine_map` attribute representation still being valid.
  - Does not yield a scheme to simplify region matching. Two such schemes that are always used together in order to constrain attributes and the op's region ought to be enough motivation for a new op - this is what this proposal is about.
Unresolved questions:
- Should repeated indices for a single operand be allowed?
  - The proposed semantics extends to this use case, but `vector.contract`, for example, explicitly disallows it (as the corresponding `affine_map` is not a projected permutation). E.g. trace as an einsum is supported at least in some frameworks (as it is unambiguous in Einstein notation, also in the multi-operand case).
This RFC benefitted from comments from and discussion with @rengolin, @asiemien, Alex Heinecke & Alex Breuer.