'iree_linalg_ext' Dialect
IREE Linalg Extensions.
A dialect designed for experimenting with non-structured operations that cannot be represented efficiently/directly by the Linalg dialect.
- 'iree_linalg_ext' Dialect
  - Operations
    - Data tiling ops
    - Non-structured ops
      - iree_linalg_ext.attention (LinalgExt::AttentionOp)
      - iree_linalg_ext.fft (LinalgExt::FftOp)
      - iree_linalg_ext.im2col (LinalgExt::Im2colOp)
      - iree_linalg_ext.online_attention (LinalgExt::OnlineAttentionOp)
      - iree_linalg_ext.reverse (LinalgExt::ReverseOp)
      - iree_linalg_ext.scan (LinalgExt::ScanOp)
      - iree_linalg_ext.scatter (LinalgExt::ScatterOp)
      - iree_linalg_ext.sort (LinalgExt::SortOp)
      - iree_linalg_ext.topk (LinalgExt::TopkOp)
    - Utility ops
    - Winograd ops
Operations
Data tiling ops
Operations for working with data layouts, padding, encodings, and other properties useful for tiling computations across iteration space dimensions.
iree_linalg_ext.pack (LinalgExt::PackOp)
Pack operation
Syntax:
operation ::= `iree_linalg_ext.pack` attr-dict
$inputs
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $outputs `:` `(` type($inputs) type($outputs) `)`
(`->` type($results)^)?
The pack operation converts an input into a tiled and packed layout. The dimensions to be tiled are given by inner_dims_pos, and the size of each tile is given by inner_tiles. The dimensions listed in inner_dims_pos do not need to be contiguous, in which case the tile is transposed. If padding_value is not set, only full tiles are handled, and it is UB if a tile size does not perfectly divide the corresponding dimension. If padding_value is set, the input is padded along the high end of each tiled dimension (i.e., at the bottom and on the right if the input has rank 2), and the result shape is dynamic in a given dimension if and only if the input shape is. Optionally, the operation also takes outer_dims_perm, which permutes the tiled (outer) loops.
Example KC_to_KCck:
iree_linalg_ext.pack %arg0 inner_dims_pos = [1, 0]
inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<16x8x32x8xf32>)
Example NC_to_NCnc:
iree_linalg_ext.pack %arg0 inner_dims_pos = [0, 1]
inner_tiles = [8, 32] into %arg1 : (memref<128x256xf32> memref<16x8x8x32xf32>)
Example KC_to_CKkc:
iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<32x4x32x8xf32>)
In the first two examples, the dimension at position 0 in the input memref (128) is tiled with a factor of 8, while the dimension at position 1 (256) is tiled with a factor of 32; in the KC_to_CKkc example, the factors are 32 and 8, respectively. In the KC_to_KCck example the point (inner) loops are interchanged, while in the KC_to_CKkc example the tiled (outer) loops are interchanged.
Example NC_to_NCnc with padding:
iree_linalg_ext.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1]
inner_tiles = [8, 2] into %arg1 : (memref<13x15xf32> memref<2x8x8x2xf32>)
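The index mapping described above can be modeled in a few lines. The following is a minimal pure-Python sketch of the pack semantics for a rank-2 input in which both dimensions are tiled; it is a reference model written for illustration, not IREE's lowering. The name pack_2d and the nested-list representation are assumptions, and the unpadded, non-divisible case (UB in the op) is modeled by raising an error.

```python
from itertools import product


def pack_2d(src, inner_dims_pos, inner_tiles, outer_dims_perm=(0, 1), padding_value=None):
    """Hypothetical reference model of pack for a rank-2 input whose two dims are both tiled."""
    dims = (len(src), len(src[0]))
    tile = [0, 0]  # tile[d] = tile size applied to input dim d
    for pos, size in zip(inner_dims_pos, inner_tiles):
        tile[pos] = size
    if padding_value is None and any(d % t for d, t in zip(dims, tile)):
        # The op leaves this case undefined (UB); the sketch rejects it instead.
        raise ValueError("tile does not divide dimension and no padding_value given")
    outer = [-(-d // t) for d, t in zip(dims, tile)]  # ceil(d / t) tiles per dim
    out_outer = [outer[p] for p in outer_dims_perm]  # outer dims after permutation
    # Start every element at padding_value; in-bounds elements are overwritten below.
    result = [[[[padding_value] * inner_tiles[1] for _ in range(inner_tiles[0])]
               for _ in range(out_outer[1])] for _ in range(out_outer[0])]
    for o0, o1, i0, i1 in product(range(out_outer[0]), range(out_outer[1]),
                                  range(inner_tiles[0]), range(inner_tiles[1])):
        outer_idx, inner_idx = [0, 0], [0, 0]
        outer_idx[outer_dims_perm[0]], outer_idx[outer_dims_perm[1]] = o0, o1
        inner_idx[inner_dims_pos[0]], inner_idx[inner_dims_pos[1]] = i0, i1
        r, c = (outer_idx[d] * tile[d] + inner_idx[d] for d in (0, 1))
        if r < dims[0] and c < dims[1]:
            result[o0][o1][i0][i1] = src[r][c]
    return result
```

On a 128x256 input, pack_2d(src, [1, 0], [32, 8]) yields a nested list of shape 16x8x32x8, the same layout as the KC_to_KCck example.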
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, LinalgExtOp
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
outer_dims_perm | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
inner_dims_pos | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_inner_tiles | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of shaped of any type values |
outputs | variadic of shaped of any type values |
inner_tiles | variadic of index |
padding_value | any type |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
iree_linalg_ext.unpack (LinalgExt::UnPackOp)
Unpack operation
Syntax:
operation ::= `iree_linalg_ext.unpack` attr-dict
$inputs
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $outputs `:` `(` type($inputs) type($outputs) `)`
(`->` type($results)^)?
The unpack operation converts a tiled and packed input to an unpacked output. See pack for more details on inner_tiles and inner_dims_pos; it is UB if a tile does not perfectly divide the corresponding dimension. Optionally, the operation also supports permuting the tiled loops via outer_dims_perm.
Example KCck_to_KC:
iree_linalg_ext.unpack %arg0 inner_dims_pos = [1, 0]
inner_tiles = [32, 8] into %arg1 : (memref<16x8x32x8xf32> memref<128x256xf32>)
Example NCnc_to_NC:
iree_linalg_ext.unpack %arg0 inner_dims_pos = [0, 1]
inner_tiles = [8, 32] into %arg1 : (memref<16x8x8x32xf32> memref<128x256xf32>)
Example CKkc_to_KC:
iree_linalg_ext.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
inner_tiles = [32, 8] into %arg0 : (memref<32x4x32x8xf32> memref<128x256xf32>)
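A matching sketch of unpack, again for a rank-2 result with both dimensions tiled: each result element (r, c) is located by splitting its indices into a tile index and an offset inside the tile, then applying outer_dims_perm and inner_dims_pos the same way pack does. The function name unpack_2d is hypothetical, and divisibility is assumed rather than checked.

```python
from itertools import product


def unpack_2d(packed, out_dims, inner_dims_pos, inner_tiles, outer_dims_perm=(0, 1)):
    """Hypothetical reference model of unpack for a rank-2 result whose two dims were tiled."""
    tile = [0, 0]  # tile[d] = tile size that was applied to result dim d
    for pos, size in zip(inner_dims_pos, inner_tiles):
        tile[pos] = size
    out = [[None] * out_dims[1] for _ in range(out_dims[0])]
    for r, c in product(range(out_dims[0]), range(out_dims[1])):
        idx = (r, c)
        outer_idx = [idx[d] // tile[d] for d in (0, 1)]  # which tile
        inner_idx = [idx[d] % tile[d] for d in (0, 1)]   # position inside that tile
        o0, o1 = (outer_idx[p] for p in outer_dims_perm)
        i0, i1 = (inner_idx[p] for p in inner_dims_pos)
        out[r][c] = packed[o0][o1][i0][i1]
    return out
```

Unpacking a tensor produced with the same inner_dims_pos, inner_tiles, and outer_dims_perm recovers the original input.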
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, LinalgExtOp
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
outer_dims_perm | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
inner_dims_pos | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_inner_tiles | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of shaped of any type values |
outputs | variadic of shaped of any type values |
inner_tiles | variadic of index |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
Non-structured ops
iree_linalg_ext.attention (LinalgExt::AttentionOp)
Attention operator
Syntax:
operation ::= `iree_linalg_ext.attention` attr-dict
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($results)^)?
Computes the scaled dot product attention function:
attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
Here Q, K, V are given tensors and scale is a scalar value specifying the scale to use.
For self-attention, all inputs and the result have the same shape BxNxd, where B is the batch dimension, N is the sequence length, and d is the head dimension. Typically N >> d. Attention usually also involves masking and dropout, but these are left out of the current implementation. For cross-attention, the query and output have the same shape (BxNxd), while the key and value differ in sequence length (they have shape BxLxd, where L != N).
After tiling, this operator produces a tiled result as in FlashAttention 2 and, optionally, the current max and sum statistics for the tile being processed.
If transpose_v is specified, the V tensor passed as input is assumed to be transposed:
attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V.T
TODO: We should move to an indexing-map-like approach so that we can generalize which tensor is transposed and which is not.
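For reference, the untiled math of the attention formula fits in a short pure-Python function operating on one batch entry (rows of Q, K, and V are sequence positions). This is an illustrative sketch, not IREE's implementation: like the op, it omits masking and dropout, and the name attention plus the list-of-lists representation are assumptions. With transpose_v=True, V is supplied as d x L and transposed back before use.

```python
import math


def attention(q, k, v, scale, transpose_v=False):
    """softmax(Q @ K.T * scale) @ V for one batch entry (hypothetical reference model)."""
    if transpose_v:
        v = [list(col) for col in zip(*v)]  # V was given as d x L; restore L x d
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) / total
                    for j in range(len(v[0]))])
    return out
```

When all keys are identical, the softmax weights are uniform and each output row is the mean of the rows of V.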
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
transpose_v | ::mlir::BoolAttr | bool attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of any type |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
iree_linalg_ext.fft (LinalgExt::FftOp)
Fft operator
Syntax:
operation ::= `iree_linalg_ext.fft` attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
`outs` `(` $outputs `:` type($outputs) `)`
(`:` type($results)^)?
Applies a 1-D FFT to the innermost dimension. This is an iterative FFT, not a recursive one, so the bit reversal is assumed to have already been applied to the input. The op carries an input, stage, which indicates the level of the reduction loop in the algorithm; the op represents the computation body for that stage. For more details, see the "Data reordering, bit reversal, and in-place algorithms" section in https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm
The size of the innermost dimension is expected to be a power of 2.
Coefficient tensors/buffers may optionally be passed as inputs; in that case, they are the second and third inputs.
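To illustrate what "iterative, with bit reversal already applied" means, here is a pure-Python radix-2 sketch that separates the one-time bit-reversal permutation from the per-stage butterflies. fft_stage plays the role of the computation body for one value of stage. The 1-based stage numbering (butterfly span 2**stage) is an assumption about how stages are counted, the twiddle factors are computed on the fly rather than passed as coefficient inputs, and all names are hypothetical.

```python
import cmath


def reverse_bits(i, bits):
    """Reverse the lowest `bits` bits of i (e.g. 0b001 -> 0b100 for bits=3)."""
    r = 0
    for _ in range(bits):
        r = (r << 1) | (i & 1)
        i >>= 1
    return r


def fft_stage(x, stage):
    """Apply butterfly stage `stage` (1-based) of an in-place radix-2 FFT to x."""
    n = len(x)
    m = 1 << stage                        # span of each butterfly group at this stage
    half = m // 2
    w_m = cmath.exp(-2j * cmath.pi / m)   # twiddle step: principal m-th root of unity
    for start in range(0, n, m):
        w = 1
        for j in range(half):
            t = w * x[start + j + half]
            u = x[start + j]
            x[start + j] = u + t
            x[start + j + half] = u - t
            w *= w_m
    return x


def fft(x):
    """Iterative FFT: bit-reverse once, then run every stage in order."""
    n = len(x)
    assert n > 0 and n & (n - 1) == 0, "size must be a power of 2"
    bits = n.bit_length() - 1
    y = [complex(x[reverse_bits(i, bits)]) for i in range(n)]
    for stage in range(1, bits + 1):
        fft_stage(y, stage)
    return y
```

Because the permutation happens once up front, each stage only reads and writes pairs of elements in place, which is what makes a single stage a natural unit for the op's body.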
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of any type |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
iree_linalg_ext.im2col (LinalgExt::Im2colOp)
Im2col operation for convolutions
Syntax:
operation ::= `iree_linalg_ext.im2col` attr-dict
`strides` `=` $strides
`dilations` `=` $dilations
`kernel_size` `=`
custom<DynamicIndexList>($kernel_size, $static_kernel_size)
`m_offset` `=`
custom<DynamicIndexList>($m_offset, $static_m_offset)
`k_offset` `=`
custom<DynamicIndexList>($k_offset, $static_k_offset)
`batch_pos` `=` $batch_pos
`m_pos` `=` $m_pos
`k_pos` `=` $k_pos
`ins` `(` $input `:` type($input) `)`
`outs` `(` $output `:` type($output) `)`
(`->` type($results)^)?
Im2col op for convolutions. The operation transforms a convolution input into an equivalent GEMM input. The op is defined by its input, output, some convolution metadata, and some indexing metadata. The strides, dilations, and kernel_size are taken from the convolution from which this op is generated, and they define how the input operand is indexed when the operation is decomposed. The shape of the output should be tensor<BxMxK>, and m_pos, k_pos, and batch_pos indicate which input dimensions map to which output dimensions.
The k_offset is an offset within the output K dimension from which the iteration space of the operation begins. This is used for tiling, since the tiled implementation must leave the output K dimension untiled. Similarly, m_offset is the offset within the output M dimension from which the iteration space of the operation begins.
The iteration space is the full output shape of the im2col op, so if the im2col op were tiled to loops with a scalar inner tile, the op
%im2col = iree_linalg_ext.im2col
    strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
    m_offset = [0] k_offset = [0]
    batch_pos = [0] m_pos = [1, 2] k_pos = [3]
    ins(%in : tensor<2x34x34x640xf32>)
    outs(%out : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
would become:
scf.for %arg0 = %c0 to %c2 step %c1
  scf.for %arg1 = %c0 to %c1024 step %c1
    scf.for %arg2 = %c0 to %c5760 step %c1
      %im2col = iree_linalg_ext.im2col
          strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
          m_offset = [%arg1] k_offset = [%arg2]
          batch_pos = [0] m_pos = [1, 2] k_pos = [3]
          ins(%in_tile : tensor<1x34x34x640xf32>)
          outs(%out_tile : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
When the tiled op is decomposed, it becomes a loop over its iteration space with an extract_slice from %in_tile followed by an insert_slice into %out_tile. The indices for the extract_slice are computed from m_offset and k_offset as:
(b, m, k) -> (b, M / 32 + K / (640*3), M % 32 + K % (640*3) / 640, K % 640)
where (b, m, k) are the indices of the tiled op's iteration space, M = m + m_offset, and K = k + k_offset. In this example, 32 is the output width (34 - 3 + 1), 640 is the number of input channels, and 3 is the kernel width.
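The following pure-Python sketch models the im2col index computation for an NHWC input (batch_pos = [0], m_pos = [1, 2], k_pos = [3]). The ordering of K as (kernel row, kernel column, channel) is read off the extract_slice index formula; the stride and dilation handling follows the usual convolution definition and is an assumption, as are the function name im2col_nhwc and the m_size/k_size parameters, which stand in for the shape of a tiled output.

```python
def im2col_nhwc(inp, kernel_size, strides=(1, 1), dilations=(1, 1),
                m_offset=0, k_offset=0, m_size=None, k_size=None):
    """Hypothetical reference model of im2col for batch_pos=[0], m_pos=[1, 2], k_pos=[3].

    inp is an N x H x W x C nested list. Output element (b, m, k) reads the input
    pixel selected by M = m + m_offset and K = k + k_offset, where K is ordered as
    (kernel row, kernel column, channel).
    """
    n, h, w, c = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    kh, kw = kernel_size
    out_h = (h - dilations[0] * (kh - 1) - 1) // strides[0] + 1
    out_w = (w - dilations[1] * (kw - 1) - 1) // strides[1] + 1
    # By default, cover the rest of the full (B x M x K) iteration space.
    if m_size is None:
        m_size = out_h * out_w - m_offset
    if k_size is None:
        k_size = kh * kw * c - k_offset
    out = []
    for b in range(n):
        rows = []
        for m in range(m_size):
            oh, ow = divmod(m + m_offset, out_w)         # output pixel for row M
            row = []
            for k in range(k_size):
                ki, rest = divmod(k + k_offset, kw * c)  # kernel row
                kj, ch = divmod(rest, c)                 # kernel column, channel
                y = oh * strides[0] + ki * dilations[0]
                x = ow * strides[1] + kj * dilations[1]
                row.append(inp[b][y][x][ch])
            rows.append(row)
        out.append(rows)
    return out
```

With m_offset=3, k_offset=17, m_size=1, and k_size=1, the call returns a single-element 1x1x1 result, mirroring one iteration of the scalar-tiled loop nest.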
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
dilations | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_kernel_size | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_m_offset | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_k_offset | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
batch_pos | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
m_pos | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
k_pos | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | shaped of any type values |
output | shaped of any type values |
kernel_size | variadic of index |
m_offset | variadic of index |
k_offset | variadic of index |
Results:
Result | Description |
---|---|
results | variadic of shaped of any type values |
iree_linalg_ext.online_attention (LinalgExt::OnlineAttentionOp)
Online Attention operator
Syntax:
operation ::= `iree_linalg_ext.online_attention` attr-dict
`ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
`outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
(`->` type($results)^)?
Traditional scaled dot product attention computes:
attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
Online attention, on the other hand, uses an online normalizer instead of softmax:
online_attention(Q, K, V, scale, running_max, running_sum) = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V
The advantage of this online_normalizer is that it can be tiled along its reduction dimension, making the online_attention operator:
- Tilable along the softmax reduction dimension
- Associative along the softmax reduction dimension
- Commutative along the softmax reduction dimension
Note: after online_attention has been computed over the entire softmax reduction dimension, its results need to be combined as follows:
x, _, sum : results
x = (1 / sum) * x
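A pure-Python sketch of an online normalizer for a single query row shows why tiling along the softmax reduction dimension works: each K/V tile updates a running max, a running sum, and an unnormalized accumulator, and earlier contributions are rescaled by exp(old_max - new_max) whenever the max grows. This is the standard online-softmax recurrence used by FlashAttention, assumed here rather than taken from IREE's decomposition; the function names are hypothetical.

```python
import math


def online_attention_step(q, k_tile, v_tile, scale, acc, running_max, running_sum):
    """Fold one K/V tile into the (acc, max, sum) state of a single query row."""
    scores = [scale * sum(a * b for a, b in zip(q, k_row)) for k_row in k_tile]
    new_max = max([running_max] + scores)
    correction = math.exp(running_max - new_max)  # rescales everything seen so far
    weights = [math.exp(s - new_max) for s in scores]
    new_sum = running_sum * correction + sum(weights)
    new_acc = [a * correction + sum(w * v_row[j] for w, v_row in zip(weights, v_tile))
               for j, a in enumerate(acc)]
    return new_acc, new_max, new_sum


def tiled_attention_row(q, k, v, scale, tile):
    """Process K/V in chunks of `tile` rows, then apply the final x = (1 / sum) * x."""
    acc, running_max, running_sum = [0.0] * len(v[0]), -math.inf, 0.0
    for start in range(0, len(k), tile):
        acc, running_max, running_sum = online_attention_step(
            q, k[start:start + tile], v[start:start + tile], scale,
            acc, running_max, running_sum)
    return [(1 / running_sum) * x for x in acc]
```

Any tile size gives the same result (up to rounding) once the final normalization is applied, which is the tiling property listed above.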
Interfaces: AggregatedOpInterface
, DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
indexing_maps | ::mlir::ArrayAttr | AffineMap array attribute |
Operands:
Operand | Description |
---|---|
query | shaped of any type values |
key | shaped of any type values |
value | shaped of any type values |
scale | floating-point |
output | shaped of any type values |
max | shaped of any type values |
sum | shaped of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
iree_linalg_ext.reverse (LinalgExt::ReverseOp)
Reverse operator
Syntax:
operation ::= `iree_linalg_ext.reverse` attr-dict `dimensions` `(` $dimensions `)`
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
(`outs` `(` $outputs^ `:` type($outputs) `)`)?
(`:` type($results)^)?
A temporary solution for lowering reverse ops into IREE, allowing IREE to tile and distribute them.
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, LinalgExtOp
, LinalgFusionOpInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of shaped of any type values |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
iree_linalg_ext.scan (LinalgExt::ScanOp)
Scan operator
Syntax:
operation ::= `iree_linalg_ext.scan` attr-dict
`dimension` `(` $dimension `)`
`inclusive` `(` $inclusive `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
Computes the inclusive/exclusive scan along a given dimension.
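The scan semantics can be sketched in pure Python for a rank-2 input, with combine standing in for the op's region. The sketch seeds every line of the scan with init, which is assumed to be the identity of combine (0 for addition, 1 for multiplication); how the op itself supplies the initial value is not covered by this description. The name scan_2d is hypothetical.

```python
import operator


def scan_2d(x, dimension, combine, inclusive=True, init=0):
    """Scan a 2-D nested list along `dimension` with the binary function `combine`."""
    rows, cols = len(x), len(x[0])
    out = [[None] * cols for _ in range(rows)]
    num_lines, length = (cols, rows) if dimension == 0 else (rows, cols)
    for line in range(num_lines):
        acc = init  # assumed to be the identity of `combine`
        for i in range(length):
            r, c = (i, line) if dimension == 0 else (line, i)
            if inclusive:
                acc = combine(acc, x[r][c])
                out[r][c] = acc
            else:
                out[r][c] = acc  # exclusive: the value before x[r][c] is folded in
                acc = combine(acc, x[r][c])
    return out
```

For example, scan_2d([[1, 2, 3]], 1, operator.add) gives the inclusive prefix sums [[1, 3, 6]], while inclusive=False shifts them by one position and starts from init.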
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dimension | ::mlir::IntegerAttr | 64-bit signless integer attribute |
inclusive | ::mlir::BoolAttr | bool attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of shaped of any type values |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
iree_linalg_ext.scatter (LinalgExt::ScatterOp)
Scatter operator
Syntax:
operation ::= `iree_linalg_ext.scatter` attr-dict `dimension_map` `=` $dimension_map
`unique_indices` `(` $unique_indices `)`
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
Based on XLA operation semantics, takes two inputs (update and indices) and outputs value (original). The operation updates the value at the slices specified by indices by combining the current value with the value in updates using the computation specified in region. The region specifies a binary operation of signature (T, T) -> T, where T is the element type of updates (and original). The first argument corresponds to the update value (i.e., from updates), and the second to the current value (i.e., the value from original).
The indices operand is a 2-D tensor/memref. The first dimension is the number of updates, and the second dimension is the index depth. The index depth should always be static.
The first dimensions of updates and indices are identical, since they both represent the number of updates.
The rank of original/result is at least index_depth + rank(%updates) - 1. The first index_depth indices are derived from indices, and the last rank(%original) - index_depth values of the update value's shape match the last dimensions of %original, with the preceding dimensions extending from the index offsets.
The dimension_map attribute describes which index value maps to which dimension in the destination. It cannot contain duplicate values, must have as many entries as the index depth, and its values must be within the rank of the destination.
The unique_indices attribute indicates whether all the indices are unique. If there are repeated indices, the first iteration loop is marked as a reduction.
The shape definitions follow the TensorFlow operation, except that batch dimensions are forced to be 1-D. See more information in https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
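The following pure-Python sketch models the scatter rules: update slices are aligned with the trailing dimensions of original, each index value is added on the destination dimension named by dimension_map, and combine receives the update value first and the current value second. It is an illustrative reference model with hypothetical names (scatter, _shape, _get, _set); it does not model unique_indices and simply applies repeated indices in order.

```python
import copy
from itertools import product


def _shape(x):
    """Shape of a nested list (scalars have shape [])."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


def _get(x, idx):
    for i in idx:
        x = x[i]
    return x


def _set(x, idx, value):
    for i in idx[:-1]:
        x = x[i]
    x[idx[-1]] = value


def scatter(updates, indices, original, dimension_map, combine):
    """Fold each update slice into a copy of `original` (original itself is untouched)."""
    result = copy.deepcopy(original)
    rank = len(_shape(original))
    for update, index in zip(updates, indices):
        slice_shape = _shape(update)
        for offsets in product(*(range(d) for d in slice_shape)):
            coord = [0] * rank
            # Trailing dims of the update slice line up with the trailing dims of original.
            for j, off in enumerate(offsets):
                coord[rank - len(slice_shape) + j] = off
            # Each index value is added on the destination dim named by dimension_map.
            for depth, dim in enumerate(dimension_map):
                coord[dim] += index[depth]
            current = _get(result, coord)
            _set(result, coord, combine(_get(update, offsets), current))
    return result
```

With an index depth of 1 and dimension_map = [0], each update is a whole row written at the row selected by its index; with an index depth of 2, each update is a scalar, and an additive region accumulates repeated indices.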
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, LinalgFusionOpInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dimension_map | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
unique_indices | ::mlir::BoolAttr | bool attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor or memref of any type values |
outputs | variadic of ranked tensor or memref of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
iree_linalg_ext.sort (LinalgExt::SortOp)
Sort operator
Syntax:
operation ::= `iree_linalg_ext.sort` attr-dict
`dimension` `(` $dimension `)`
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
Based on XLA operation semantics, sorts the given operands along the given dimension with the given comparator.
See https://www.tensorflow.org/xla/operation_semantics#sort.
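A pure-Python sketch of a multi-operand sort over 1-D operands: all operands are permuted together according to the comparator. Passing the comparator interleaved (lhs, rhs) pairs per operand follows the XLA sort convention; whether IREE's region arguments use exactly this order is an assumption, and the name sort_operands is hypothetical. Python's sorted is stable, so equal keys keep their input order in this sketch.

```python
from functools import cmp_to_key


def sort_operands(operands, comparator):
    """Sort equal-length 1-D operands together, as one sort over several operands.

    comparator receives 2 * len(operands) scalars, (lhs_0, rhs_0, lhs_1, rhs_1, ...),
    and returns True when the lhs element should be ordered before the rhs element.
    """
    def cmp(i, j):
        forward = [x for op in operands for x in (op[i], op[j])]
        backward = [x for op in operands for x in (op[j], op[i])]
        if comparator(*forward):
            return -1
        if comparator(*backward):
            return 1
        return 0

    order = sorted(range(len(operands[0])), key=cmp_to_key(cmp))
    return [[op[i] for i in order] for op in operands]
```

Sorting a key operand together with a value operand, using a comparator that only looks at the keys, reorders the values to follow their keys.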
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dimension | ::mlir::IntegerAttr | 64-bit signless integer attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of any type |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
iree_linalg_ext.topk (LinalgExt::TopkOp)
Top-K operator
Syntax:
operation ::= `iree_linalg_ext.topk` attr-dict
`dimension` `(` $dimension `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
A Top-K operation for N-D tensors. Reduces the target dimension from the input size N down to K elements based on the supplied binary region.
Accepts an N-D tensor input consisting of values and an optional N-D tensor for the indices of those values (i32 type). If input indices aren't provided, the index mapping is inferred based on the k dimension. The input values and indices tensors must have the same shape, as must the output values and indices tensors. Top-K is computed along the target dimension (from dimension()). Returns two output tensors: the values and the indices of the Top-K results. The output dimensions must match the input, except for the dimension that is reduced to K results.
The region accepts lhs=[next N input] and rhs=[existing K output] and yields an i1. If true, the two values are swapped:
- For Top-K comparison: >
- For Min-K comparison: <
Note: when the two values are equal, the first occurrence is always selected.
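The region semantics can be sketched in pure Python along a single (1-D) dimension: each incoming value is compared against the current K outputs with comparator(lhs=new, rhs=existing), moved ahead while the comparator returns true, and the buffer is then truncated back to K. Because a strict comparator never moves a value past an equal one, the first occurrence wins ties, matching the note on equal values. The name topk_1d and the insertion-based strategy are assumptions for illustration.

```python
def topk_1d(values, k, comparator, indices=None):
    """Keep k elements of `values`; comparator(lhs, rhs) True means lhs goes ahead of rhs."""
    if indices is None:
        indices = list(range(len(values)))  # inferred index mapping
    out_vals, out_idx = [], []
    for v, i in zip(values, indices):
        pos = len(out_vals)
        # Walk the new value forward past every existing output it beats.
        while pos > 0 and comparator(v, out_vals[pos - 1]):
            pos -= 1
        if pos < k:
            out_vals.insert(pos, v)
            out_idx.insert(pos, i)
            del out_vals[k:], out_idx[k:]  # keep only the best k
    return out_vals, out_idx
```

Passing lambda a, b: a > b selects the K largest values; lambda a, b: a < b selects the K smallest.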
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, LinalgExtOp
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dimension | ::mlir::IntegerAttr | 64-bit signless integer attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of shaped of any type values |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
Utility ops
iree_linalg_ext.yield (LinalgExt::YieldOp)
LinalgExt yield op
Syntax:
operation ::= `iree_linalg_ext.yield` attr-dict ($operands^ `:` type($operands))?
iree_linalg_ext.yield is a special terminator operation for blocks inside regions in iree_linalg_ext ops.
Traits: AlwaysSpeculatableImplTrait
, ReturnLike
, Terminator
Interfaces: ConditionallySpeculatable
, NoMemoryEffect (MemoryEffectOpInterface)
, RegionBranchTerminatorOpInterface
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
operands | variadic of any type |
Winograd ops
iree_linalg_ext.winograd.filter_transform (LinalgExt::WinogradFilterTransformOp)
Winograd Filter Transform operator
Syntax:
operation ::= `iree_linalg_ext.winograd.filter_transform` attr-dict
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`kernel_dimensions` `(` $kernel_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
This operator is part of the first step in converting a convolution to its Winograd equivalent. Given a tile of a convolution filter (F), this operator computes matmul(G, matmul(F, transpose(G))). The filter tile is assumed to be the full m x m convolutional kernel, and the result of the transformation on this tile is a square with each side of size m + r - 1, where the output tile size is r x r. G is a constant 2-d matrix of shape (m + r - 1) x m. The input to the operator is a filter of shape (H, W, C, F) or (F, C, H, W), and the output is a tensor of shape (m + r - 1, m + r - 1, C, F). The result of this operator is first collapsed and then fed to a batch matmul op.
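A concrete instance helps. For F(2x2, 3x3), i.e. a 3 x 3 kernel (m = 3) and 2 x 2 output tiles (r = 2), the standard matrices from Lavin and Gray's Winograd convolution work give a 4 x 3 G, and the filter transform of one tile is G F G^T (G appears on both sides, which is what makes its (m + r - 1) x m shape conform). The pure-Python sketch below, with hypothetical helper names, applies it to a single m x m tile; the op applies the same transform across the C and F dimensions.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


# Filter matrix G for F(2x2, 3x3): m = 3, r = 2, shape (m + r - 1) x m = 4 x 3.
G = [[1.0, 0.0, 0.0],
     [0.5, 0.5, 0.5],
     [0.5, -0.5, 0.5],
     [0.0, 0.0, 1.0]]


def filter_transform(f):
    """Map one m x m filter tile to the (m + r - 1) x (m + r - 1) Winograd domain."""
    return matmul(G, matmul(f, transpose(G)))
```

For an all-ones filter, every entry of the 4 x 4 result is the product of two row sums of G.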
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
output_tile_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
kernel_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
kernel_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of shaped of any type values |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
result | variadic of ranked tensor of any type values |
iree_linalg_ext.winograd.input_transform (LinalgExt::WinogradInputTransformOp)
Winograd Input Transform operator
Syntax:
operation ::= `iree_linalg_ext.winograd.input_transform` attr-dict
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`image_dimensions` `(` $image_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
This operator is part of the first step in converting a convolution to its Winograd equivalent. Given a tile of an input image (I), this operator computes matmul(transpose(B), matmul(I, B)). The input tile is assumed to be square with each side of size m + r - 1, where the convolutional kernel is m x m and the output tile size is r x r. B is a constant 2-d square matrix of the same shape as the input tile I. The input to the operator is an image of shape (N, H, W, C) or (N, C, H, W), and the output is a tensor of shape (m + r - 1, m + r - 1, N, H', W', C), where H' = ceil((H - m + 1)/r) and W' = ceil((W - m + 1)/r). The result of this operator is first collapsed and then fed to a batch matmul op.
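Continuing the F(2x2, 3x3) instance (m = 3, r = 2), B^T is the 4 x 4 matrix below, and num_tiles computes the H' = ceil((H - m + 1) / r) tile count from the description. This is a single-tile sketch, not the batched op; the matrix values are the standard Lavin and Gray ones, and all names are hypothetical.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


# Input matrix B^T for F(2x2, 3x3): B is square, (m + r - 1) x (m + r - 1) = 4 x 4.
BT = [[1, 0, -1, 0],
      [0, 1, 1, 0],
      [0, -1, 1, 0],
      [0, 1, 0, -1]]


def input_transform(d):
    """Map one 4 x 4 input tile I to B^T I B."""
    return matmul(BT, matmul(d, transpose(BT)))


def num_tiles(extent, m=3, r=2):
    """Tiles along one spatial dim: ceil((extent - m + 1) / r)."""
    return -(-(extent - m + 1) // r)
```

A 7 x 7 image, for example, has a 5 x 5 valid output, which needs 3 tiles of size 2 per dimension (the last one partly padded).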
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
output_tile_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
kernel_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
image_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of shaped of any type values |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
result | variadic of ranked tensor of any type values |
iree_linalg_ext.winograd.output_transform (LinalgExt::WinogradOutputTransformOp)
Winograd Output Transform operator
Syntax:
operation ::= `iree_linalg_ext.winograd.output_transform` attr-dict
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`image_dimensions` `(` $image_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
This operator is the last transform in converting a convolution to its Winograd equivalent. After convolution in the Winograd domain (which turns into an elementwise product for a single channel and batch matrix multiplication for many channels), this operator converts the output back into the original domain. Given a tile of the output (O) in the Winograd domain, this operator computes matmul(transpose(A), matmul(O, A)). The output tile is square with each side of size m + r - 1, where the convolutional kernel is m x m and the output tile size is r x r. A is a constant 2-d matrix of shape (m + r - 1) x r. The input to the operator is a tensor of shape (m + r - 1, m + r - 1, N, H', W', C) and the output is a tensor of shape (N, H, W, C) or (N, C, H, W) where H = r H' and W = r W'. This operator is followed by a tensor.extract_slice which extracts only the non-padded part of the output.
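The sketch below closes the loop for a single channel and a single 4 x 4 input tile, using the standard F(2x2, 3x3) matrices (m = 3, r = 2): filter transform, input transform, elementwise product, then the output transform A^T O A, checked against a direct 3 x 3 cross-correlation. The multi-channel batch matmul, tile stitching, and the final extract_slice are not modeled, and all names are hypothetical.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


# F(2x2, 3x3) transform matrices: m = 3 (kernel), r = 2 (output tile), m + r - 1 = 4.
G = [[1.0, 0.0, 0.0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0.0, 0.0, 1.0]]
BT = [[1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0], [0, 1, 0, -1]]
AT = [[1, 1, 1, 0], [0, 1, -1, -1]]  # A^T; A itself is (m + r - 1) x r = 4 x 2


def output_transform(o):
    """Map one 4 x 4 Winograd-domain tile O back to a 2 x 2 output tile: A^T O A."""
    return matmul(AT, matmul(o, transpose(AT)))


def winograd_tile(d, g):
    """Single-channel pipeline: both forward transforms, elementwise product, output transform."""
    u = matmul(G, matmul(g, transpose(G)))
    v = matmul(BT, matmul(d, transpose(BT)))
    return output_transform([[u[i][j] * v[i][j] for j in range(4)] for i in range(4)])


def direct_tile(d, g):
    """Valid 3 x 3 cross-correlation of a 4 x 4 tile, for comparison."""
    return [[sum(d[i + a][j + b] * g[a][b] for a in range(3) for b in range(3))
             for j in range(2)] for i in range(2)]
```

The elementwise product in winograd_tile is the single-channel case; with many channels, that step becomes the batch matrix multiplication mentioned in the description.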
Traits: AttrSizedOperandSegments
, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>
, SingleBlock
Interfaces: DestinationStyleOpInterface
, LinalgExtInterface
, MemoryEffectOpInterface
, ReifyRankedShapedTypeOpInterface
, TilingInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
output_tile_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
kernel_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
image_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of shaped of any type values |
outputs | variadic of shaped of any type values |
Results:
Result | Description |
---|---|
result | variadic of ranked tensor of any type values |