'iree_linalg_ext' Dialect

IREE Linalg Extensions.

A dialect designed for experimenting with non-structured operations that cannot be represented efficiently/directly by the Linalg dialect.

Operations

Data tiling ops

Operations for working with data layouts, padding, encodings, and other properties useful for tiling computations across iteration space dimensions.

iree_linalg_ext.pack (LinalgExt::PackOp)

Pack operation

Syntax:

operation ::= `iree_linalg_ext.pack` attr-dict
              $inputs
              (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
              (`outer_dims_perm` `=` $outer_dims_perm^)?
              `inner_dims_pos` `=` $inner_dims_pos
              `inner_tiles` `=`
              custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
              `into` $outputs `:` `(` type($inputs) type($outputs) `)`
              (`->` type($results)^)?

The pack operation converts an input into a tiled and packed layout. The dimensions to be tiled are given by inner_dims_pos and the tile sizes by inner_tiles. The dimensions listed in inner_dims_pos need not be in order (for example, [1, 0]), in which case the tile is transposed. If padding_value is not set, only full tiles are handled: it is undefined behavior (UB) if a tile size does not evenly divide its dimension. If padding_value is set, padding is applied at the high end of each tiled dimension, i.e., at the bottom and on the right for a rank-2 input, and the result shape is dynamic in a given dimension if and only if the input shape is. The optional outer_dims_perm permutes the tiled loops.

Example KC_to_KCck:

iree_linalg_ext.pack %arg0 inner_dims_pos = [1, 0]
  inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<16x8x32x8xf32>)

Example NC_to_NCnc:

iree_linalg_ext.pack %arg0 inner_dims_pos = [0, 1]
  inner_tiles = [8, 32] into %arg1 : (memref<128x256xf32> memref<16x8x8x32xf32>)

Example KC_to_CKkc:

iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
  inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<32x4x32x8xf32>)

In the KC_to_KCck and NC_to_NCnc examples, the dimension at position 0 in the input memref (128) is tiled with a factor of 8, while the dimension at position 1 (256) is tiled with a factor of 32. In the KC_to_KCck example, the point loops are interchanged. In the KC_to_CKkc example, dimension 0 is tiled with a factor of 32 and dimension 1 with a factor of 8, and the tiled loops are interchanged.

Example NC_to_NCnc with padding:

iree_linalg_ext.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1]
  inner_tiles = [8, 2] into %arg1 : (memref<13x15xf32> memref<2x8x8x2xf32>)
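
The index arithmetic behind the examples above can be sketched in plain Python. This is an illustrative model only, not IREE's implementation: the helper names packed_shape and pack_2d are hypothetical, the input is a rank-2 nested list with both dimensions tiled, and outer_dims_perm[j] is assumed to name the source dimension that result outer dimension j iterates over.

```python
import math


def packed_shape(shape, inner_dims_pos, inner_tiles, outer_dims_perm=None):
    """Return the packed shape: outer (tile-count) dims followed by the inner tiles."""
    outer = list(shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] = math.ceil(outer[pos] / tile)  # number of tiles along source dim `pos`
    if outer_dims_perm is not None:
        outer = [outer[p] for p in outer_dims_perm]  # permute the tiled loops
    return outer + list(inner_tiles)


def pack_2d(src, inner_dims_pos, inner_tiles, outer_dims_perm=None, padding_value=None):
    """Pack a rank-2 nested list into a rank-4 nested list (assumes both dims are tiled)."""
    rows, cols = len(src), len(src[0])
    perm = outer_dims_perm if outer_dims_perm is not None else [0, 1]
    tile_of = dict(zip(inner_dims_pos, inner_tiles))  # source dim -> tile size
    n0, n1, t0, t1 = packed_shape([rows, cols], inner_dims_pos, inner_tiles, perm)
    out = [[[[None] * t1 for _ in range(t0)] for _ in range(n1)] for _ in range(n0)]
    for o0 in range(n0):
        for o1 in range(n1):
            outer = [0, 0]
            outer[perm[0]], outer[perm[1]] = o0, o1  # outer index per source dim
            for i0 in range(t0):
                for i1 in range(t1):
                    inner = [0, 0]
                    inner[inner_dims_pos[0]], inner[inner_dims_pos[1]] = i0, i1
                    r = outer[0] * tile_of[0] + inner[0]
                    c = outer[1] * tile_of[1] + inner[1]
                    if r < rows and c < cols:
                        out[o0][o1][i0][i1] = src[r][c]
                    elif padding_value is not None:
                        out[o0][o1][i0][i1] = padding_value  # pad at the high end
                    else:
                        raise ValueError("tile does not divide the dimension and no padding_value was given")
    return out
```

With these definitions, packed_shape([128, 256], [1, 0], [32, 8]) gives [16, 8, 32, 8], which matches the KC_to_KCck result type, and packed_shape([13, 15], [0, 1], [8, 2]) gives [2, 8, 8, 2], which matches the padded example.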

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, LinalgExtOp, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| outer_dims_perm | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| inner_dims_pos | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| static_inner_tiles | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of shaped of any type values |
| outputs | variadic of shaped of any type values |
| inner_tiles | variadic of index |
| padding_value | any type |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

iree_linalg_ext.set_encoding (LinalgExt::SetEncodingOp)

Perform pack and pad operation on source

Syntax:

operation ::= `iree_linalg_ext.set_encoding` attr-dict $source `:` type($source) `->` type($result)

Operation to assign an encoding to a tensor. The operation does not change the rank or extent of a tensor. Instead it adds an encoding attribute to the tensor type to represent a change in layout.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ReifyRankedShapedTypeOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| source | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |

iree_linalg_ext.unpack (LinalgExt::UnPackOp)

Unpack operation

Syntax:

operation ::= `iree_linalg_ext.unpack` attr-dict
              $inputs
              (`outer_dims_perm` `=` $outer_dims_perm^)?
              `inner_dims_pos` `=` $inner_dims_pos
              `inner_tiles` `=`
              custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
              `into` $outputs `:` `(` type($inputs) type($outputs) `)`
              (`->` type($results)^)?

The unpack operation converts a tiled and packed input to an unpacked output. See pack for more details on inner_tiles and inner_dims_pos; it is undefined behavior (UB) if the tile does not perfectly divide the dimension. Optionally, the operation also supports permuting the tiled loops through outer_dims_perm.

Example KCck_to_KC:

iree_linalg_ext.unpack %arg0 inner_dims_pos = [1, 0]
  inner_tiles = [32, 8] into %arg1 : (memref<16x8x32x8xf32> memref<128x256xf32>)

Example NCnc_to_NC:

iree_linalg_ext.unpack %arg0 inner_dims_pos = [0, 1]
  inner_tiles = [8, 32] into %arg1 : (memref<16x8x8x32xf32> memref<128x256xf32>)

Example CKkc_to_KC:

iree_linalg_ext.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
  inner_tiles = [32, 8] into %arg0 : (memref<32x4x32x8xf32> memref<128x256xf32>)
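
As a rough illustration of how unpack inverts the packed indexing, the following Python sketch maps every element of a rank-2 output back to its position in a rank-4 packed input. The helper name unpack_2d is hypothetical, both dimensions are assumed to be tiled, tiles are assumed to divide the dimensions exactly, and outer_dims_perm[j] is assumed to name the source dimension that packed outer dimension j iterates over.

```python
def unpack_2d(packed, out_shape, inner_dims_pos, inner_tiles, outer_dims_perm=None):
    """Unpack a rank-4 nested list into a rank-2 nested list of shape `out_shape`."""
    rows, cols = out_shape
    perm = outer_dims_perm if outer_dims_perm is not None else [0, 1]
    tile_of = dict(zip(inner_dims_pos, inner_tiles))  # unpacked dim -> tile size
    out = [[None] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            # Split each coordinate into (which tile, offset inside the tile).
            outer = [r // tile_of[0], c // tile_of[1]]
            inner = [r % tile_of[0], c % tile_of[1]]
            # Reorder into the packed layout's outer and inner index order.
            o0, o1 = outer[perm[0]], outer[perm[1]]
            i0, i1 = inner[inner_dims_pos[0]], inner[inner_dims_pos[1]]
            out[r][c] = packed[o0][o1][i0][i1]
    return out
```

For instance, a 2x2x2x3 packed input with inner_dims_pos = [0, 1] and inner_tiles = [2, 3] unpacks to a 4x6 result in which element (r, c) comes from tile (r // 2, c // 3) at offset (r % 2, c % 3).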

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, LinalgExtOp, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| outer_dims_perm | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| inner_dims_pos | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| static_inner_tiles | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of shaped of any type values |
| outputs | variadic of shaped of any type values |
| inner_tiles | variadic of index |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

iree_linalg_ext.unset_encoding (LinalgExt::UnsetEncodingOp)

Perform unpack and extract operation on source

Syntax:

operation ::= `iree_linalg_ext.unset_encoding` attr-dict $source `:` type($source) `->` type($result)

Operation to convert a tensor with an encoding that represents its data layout into a tensor with the default layout (i.e., no encoding). For now, the default layout in IREE is row-major.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ReifyRankedShapedTypeOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| source | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |

iree_linalg_ext.upper_bound_tile_size (LinalgExt::UpperBoundTileSizeOp)

Returns an upper bound on tile sizes

Syntax:

operation ::= `iree_linalg_ext.upper_bound_tile_size` attr-dict $tensorType `->` type($results)

This returns the largest tile sizes that might result from materialization of the given encoding. This can be used outside of target-specific code, so there may be multiple targets, and this will return the maximum tile size from iterating over all of them. The evaluation happens in the MaterializeUpperBoundTileSize pass.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| tensorType | ::mlir::TypeAttr | type attribute of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of index |

Non-structured ops

iree_linalg_ext.attention (LinalgExt::AttentionOp)

Attention operator

Syntax:

operation ::= `iree_linalg_ext.attention` attr-dict
              `ins` `(` $inputs `:` type($inputs) `)`
              `outs` `(` $outputs `:` type($outputs) `)`
              (`->` type($results)^)?

Computes the scaled dot product attention function:

attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V

Here Q, K, V are given tensors and scale is a scalar value specifying the scale to use.

For self-attention, all inputs and the result have the same shape BxNxd where B is the batch dimension, N is the sequence length and d is head dimension. Typically N >>> d. Usually, this operator also performs masking and dropout, but we leave that out of the current implementation. For cross-attention, the query and output have the same shape (BxNxd), while the key and value differ in sequence length (they have shape BxLxd, where L != N).

After tiling, this operator produces a tiled result following FlashAttention 2, and it can optionally also produce the current max and sum statistics for the tile being processed.

If transpose_v is specified, the V tensor passed as input is assumed to be transposed:

attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V.T

TODO: We should move to an indexing-map-like approach so we can generalize which tensor is transposed and which is not.
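
To make the formula and the role of the max and sum statistics concrete, here is a small pure-Python sketch for a single batch. It is not the IREE lowering: the function names attention and attention_tiled are hypothetical, and the tiled variant is only a simplified model of the running-max and running-sum bookkeeping used by FlashAttention-style algorithms.

```python
import math


def transpose(a):
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, scale, transpose_v=False):
    """Reference: softmax(Q @ K.T * scale) @ V, with V given transposed if transpose_v."""
    if transpose_v:
        v = transpose(v)
    scores = matmul(q, transpose(k))
    probs = [softmax([s * scale for s in row]) for row in scores]
    return matmul(probs, v)


def attention_tiled(q, k, v, scale, tile):
    """Same result, visiting K/V in tiles while tracking a running max and sum per query row."""
    d = len(v[0])
    out = []
    for q_row in q:
        run_max, run_sum, acc = float("-inf"), 0.0, [0.0] * d
        for start in range(0, len(k), tile):
            k_tile, v_tile = k[start:start + tile], v[start:start + tile]
            scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_tile]
            new_max = max(run_max, max(scores))
            correction = math.exp(run_max - new_max)  # rescales the earlier partial results
            weights = [math.exp(s - new_max) for s in scores]
            run_sum = run_sum * correction + sum(weights)
            acc = [a * correction + sum(w * v_row[j] for w, v_row in zip(weights, v_tile))
                   for j, a in enumerate(acc)]
            run_max = new_max
        out.append([a / run_sum for a in acc])
    return out
```

Both functions accept Q of shape N x d and K, V of shape L x d, so the cross-attention case (L != N) is covered as well; the tiled version should agree with the reference up to floating-point rounding.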

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| transpose_v | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of any type |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

iree_linalg_ext.fft (LinalgExt::FftOp)

Fft operator

Syntax:

operation ::= `iree_linalg_ext.fft` attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
              `outs` `(` $outputs `:` type($outputs) `)`
              (`:` type($results)^)?

Applies a 1D FFT to the innermost dim. This is an iterative FFT, not a recursive one, so bit reversal is assumed to have already been applied to the input. The op carries an input, stage, which indicates the level of the reduction loop in the algorithm; the op represents the computation body of that loop. For more details, see the "Data reordering, bit reversal, and in-place algorithms" section in https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm

The size of innermost dim is expected to be a power of 2.

It is optional to carry coefficient tensors/buffers as inputs. In this context, they will be the second and third inputs.
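
The stage-by-stage structure described above can be illustrated with a textbook iterative radix-2 Cooley-Tukey FFT in Python. This is a sketch, not IREE's implementation: the function names are hypothetical, stages are assumed to be numbered from 1, and the twiddle factors are computed inline rather than passed in as coefficient inputs.

```python
import cmath


def bit_reverse(i, bits):
    """Reverse the lowest `bits` bits of integer i."""
    r = 0
    for _ in range(bits):
        r = (r << 1) | (i & 1)
        i >>= 1
    return r


def fft_stage(data, stage):
    """Apply one butterfly stage in place; `stage` is assumed to start at 1."""
    half = 1 << (stage - 1)
    span = half * 2
    for start in range(0, len(data), span):
        for j in range(half):
            w = cmath.exp(-2j * cmath.pi * j / span)  # twiddle factor
            a = data[start + j]
            b = data[start + j + half] * w
            data[start + j] = a + b
            data[start + j + half] = a - b


def fft(x):
    """Full FFT: bit-reverse the input once, then run log2(n) stages."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    bits = n.bit_length() - 1
    data = [complex(x[bit_reverse(i, bits)]) for i in range(n)]
    for stage in range(1, bits + 1):
        fft_stage(data, stage)
    return data
```

In this model each call to fft_stage plays the role of one op instance, and the bit-reversal permutation is done once up front, outside the stage loop.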

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of any type |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

iree_linalg_ext.reverse (LinalgExt::ReverseOp)

Reverse operator

Syntax:

operation ::= `iree_linalg_ext.reverse` attr-dict `dimensions` `(` $dimensions `)`
              (`ins` `(` $inputs^ `:` type($inputs) `)`)?
              (`outs` `(` $outputs^ `:` type($outputs) `)`)?
              (`:` type($results)^)?

A temporary solution for lowering reverse ops into IREE, allowing IREE to tile and distribute them.

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, LinalgExtOp, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of shaped of any type values |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

iree_linalg_ext.scan (LinalgExt::ScanOp)

Scan operator

Syntax:

operation ::= `iree_linalg_ext.scan` attr-dict
              `dimension` `(` $dimension `)`
              `inclusive` `(` $inclusive `)`
              `ins` `(` $inputs `:` type($inputs) `)`
              `outs` `(` $outputs `:` type($outputs) `)`
              $region (`->` type($results)^)?

Computes the inclusive/exclusive scan along a given dimension.
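
The difference between the inclusive and exclusive forms can be sketched in Python as follows. The helper names scan_1d and scan_2d are hypothetical, and treating the exclusive scan as starting from an explicit initial accumulator value (init) is an assumption of this sketch rather than something stated above.

```python
def scan_1d(values, combine, init, inclusive):
    """Scan a list with a binary `combine`; `init` seeds the accumulator (an assumption)."""
    out, acc = [], init
    for v in values:
        if inclusive:
            acc = combine(acc, v)  # output includes the current element
            out.append(acc)
        else:
            out.append(acc)        # output excludes the current element
            acc = combine(acc, v)
    return out


def scan_2d(matrix, dimension, combine, init, inclusive=True):
    """Scan a rank-2 nested list along `dimension` (0 = down columns, 1 = along rows)."""
    if dimension == 1:
        return [scan_1d(row, combine, init, inclusive) for row in matrix]
    cols = [scan_1d(list(col), combine, init, inclusive) for col in zip(*matrix)]
    return [list(row) for row in zip(*cols)]
```

With addition as the combiner, scanning [1, 2, 3, 4] inclusively yields the running totals 1, 3, 6, 10, while the exclusive form shifts them by one position and starts from init.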

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dimension | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| inclusive | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of shaped of any type values |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

iree_linalg_ext.scatter (LinalgExt::ScatterOp)

Scatter operator

Syntax:

operation ::= `iree_linalg_ext.scatter` attr-dict `dimension_map` `=` $dimension_map
              `unique_indices` `(` $unique_indices `)`
              (`ins` `(` $inputs^ `:` type($inputs) `)`)?
              `outs` `(` $outputs `:` type($outputs) `)`
              $region (`->` type($results)^)?

Based on XLA operation semantics, this op takes two inputs (updates and indices) and one output value (original). The operation updates the value at the slices specified by indices by combining the current value with the value in updates using the computation specified in region. The region specifies a binary operation of signature (T, T) -> T, where T is the element type of updates (and original). The first argument corresponds to the value to be updated (i.e., from updates), and the second to the current value (i.e., the value from original).

The indices is a 2D tensor/memref type. The first dim is the number of updates, and the second dim is index depth. The index depth should always be static.

The first dim of updates and indices is identical, since they represent the number of updates.

The rank of the original/result is at least index_depth + rank(%updates) - 1. The first index_depth indices are derived from indices, and the last rank(%original) - index_depth dimensions of each update value match the last dimensions of %original, with the preceding dims extending from the index offsets.

The dimension_map attribute describes which index value maps to which dimension in the destination. It cannot contain duplicate values, must have as many entries as the index depth, and its values must be within the rank of the destination.

The unique_indices attribute carries the information whether all the indices are unique. If there are repeated indices, the first iteration loop will be marked as reduction.

The shape definitions follow the TensorFlow operations, except that batch dims are forced to be 1D. See more information in https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
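
A minimal Python model of these semantics, operating on nested lists, is sketched below. It is illustrative only: the helper names are hypothetical, it covers just the case where rank(original) equals index_depth + rank(updates) - 1, and it assumes dimension_map is a permutation of the leading index_depth destination dimensions.

```python
import copy
import itertools


def shape_of(x):
    """Shape of a rectangular nested list ([] for a scalar)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return shape


def get_at(x, coords):
    for c in coords:
        x = x[c]
    return x


def set_at(x, coords, value):
    for c in coords[:-1]:
        x = x[c]
    x[coords[-1]] = value


def scatter(original, updates, indices, dimension_map, combine):
    """Combine each update slice into a copy of `original` at the position given by its index row.

    `combine(update_value, current_value)` mirrors the region's argument order.
    """
    result = copy.deepcopy(original)
    depth = len(dimension_map)
    for update, index in zip(updates, indices):
        slice_ranges = [range(n) for n in shape_of(update)]
        for offsets in itertools.product(*slice_ranges):
            coords = [0] * depth + list(offsets)  # trailing dims come from the update slice
            for i, dim in enumerate(dimension_map):
                coords[dim] = index[i]            # leading dims come from the index row
            set_at(result, coords, combine(get_at(update, offsets), get_at(result, coords)))
    return result
```

Passing a combiner such as lambda u, c: u + c accumulates repeated indices, which corresponds to the reduction case mentioned for unique_indices, while lambda u, c: u simply overwrites the destination.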

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dimension_map | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| unique_indices | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor or memref of any type values |
| outputs | variadic of ranked tensor or memref of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

iree_linalg_ext.sort (LinalgExt::SortOp)

Sort operator

Syntax:

operation ::= `iree_linalg_ext.sort` attr-dict
              `dimension` `(` $dimension `)`
              (`ins` `(` $inputs^ `:` type($inputs) `)`)?
              `outs` `(` $outputs `:` type($outputs) `)`
              $region (`->` type($results)^)?

Based on XLA operation semantics, sorts the given operands at the given dimension with the given comparator.

See https://www.tensorflow.org/xla/operation_semantics#sort.
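
As a hedged illustration of sorting along a dimension with a user-supplied comparator, the Python sketch below sorts a rank-2 nested list. The name sort_2d is hypothetical, and the comparator is assumed to return true when its first argument should come before its second, in the spirit of the XLA comparator.

```python
import functools


def sort_2d(matrix, dimension, comes_before):
    """Sort a rank-2 nested list along `dimension` using a boolean comparator."""
    def cmp(a, b):
        if comes_before(a, b):
            return -1
        if comes_before(b, a):
            return 1
        return 0

    key = functools.cmp_to_key(cmp)
    if dimension == 1:
        return [sorted(row, key=key) for row in matrix]       # sort each row
    cols = [sorted(col, key=key) for col in zip(*matrix)]     # sort each column
    return [list(row) for row in zip(*cols)]
```

Swapping the comparator from lambda a, b: a < b to lambda a, b: a > b switches between ascending and descending order without touching the rest of the code, which is the role the region plays in the op.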

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dimension | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of any type |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

iree_linalg_ext.topk (LinalgExt::TopkOp)

Top-K operator

Syntax:

operation ::= `iree_linalg_ext.topk` attr-dict
              `dimension` `(` $dimension `)`
              `ins` `(` $inputs `:` type($inputs) `)`
              `outs` `(` $outputs `:` type($outputs) `)`
              $region (`->` type($results)^)?

A Top-K operation for N-D tensors. Reduces the target dimension from the input size N down to K elements based on the supplied binary region.

Accepts an N-D tensor input consisting of values and an optional N-D tensor for indices of those values (i32 type). If input indices aren't provided, the index mapping is inferred based on the k dim. The input values and indices tensors must have the same shape as each other, as must the output values and indices tensors. Top-K is computed along the target dimension (from dimension()). Returns two output tensors: the values and the indices of the Top-K results. The output dimensions must match the input save for the dimension that is reduced to K results.

The region accepts lhs=[next N input] and rhs=[existing K output] and yields an i1. If true, the two values are swapped:
- For Top-K comparison: >
- For Min-K comparison: <

Note: when the two values are equal, the first occurrence is always selected.
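
The selection rule and its tie-breaking behavior can be modeled with a short Python sketch over a 1-D list. The function name topk is hypothetical and the insertion-based strategy is just one simple way to realize the described behavior, not necessarily how IREE lowers the op.

```python
def topk(values, k, compare, indices=None):
    """Keep the K best values (and their indices) according to `compare(lhs, rhs)`.

    `compare` receives the next input value and an existing output value and
    returns True when the input should be placed ahead of that output.
    """
    if indices is None:
        indices = list(range(len(values)))  # inferred index mapping
    out_vals, out_idx = [], []
    for value, index in zip(values, indices):
        pos = len(out_vals)
        for j, existing in enumerate(out_vals):
            if compare(value, existing):    # strict comparison keeps earlier ties first
                pos = j
                break
        if pos < k:
            out_vals.insert(pos, value)
            out_idx.insert(pos, index)
            del out_vals[k:]                # drop anything pushed past K
            del out_idx[k:]
    return out_vals, out_idx
```

For example, with values [1, 5, 3, 5, 2], K = 2, and a greater-than comparator, the two 5s are returned with indices 1 and 3 in that order, so the first occurrence wins the tie; using a less-than comparator instead yields the Min-K result.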

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, LinalgExtOp, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dimension | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of shaped of any type values |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

Utility ops

iree_linalg_ext.yield (LinalgExt::YieldOp)

LinalgExt yield op

Syntax:

operation ::= `iree_linalg_ext.yield` attr-dict ($operands^ `:` type($operands))?

iree_linalg_ext.yield is a special terminator operation for blocks inside regions in iree_linalg_ext ops.

Traits: AlwaysSpeculatableImplTrait, ReturnLike, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| operands | variadic of any type |

Winograd ops

iree_linalg_ext.winograd.filter_transform (LinalgExt::WinogradFilterTransformOp)

Winograd Filter Transform operator

Syntax:

operation ::= `iree_linalg_ext.winograd.filter_transform` attr-dict
              `output_tile_size` `(` $output_tile_size `)`
              `kernel_size` `(` $kernel_size `)`
              `kernel_dimensions` `(` $kernel_dimensions `)`
              `ins` `(` $inputs `:` type($inputs) `)`
              `outs` `(` $outputs `:` type($outputs) `)`
              (`->` type($result)^)?

This operator is part of the first step in converting a convolution to its Winograd equivalent. Given a tile of a convolution filter (F), this operator computes matmul(G, matmul(F, transpose(G))). The filter tile is assumed to be the full m x m convolutional kernel, and the result of the transformation on this tile is a square with each side of size m + r - 1, where the output tile size is r x r. G is a constant 2-d matrix of shape (m + r - 1) x m. The input to the operator is a filter of shape (H, W, C, F) or (F, C, H, W) and the output is a tensor of shape (m + r - 1, m + r - 1, C, F). The result of this operator is first collapsed and then fed to a batch matmul op.

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| output_tile_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| kernel_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| kernel_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of shaped of any type values |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| result | variadic of ranked tensor of any type values |

iree_linalg_ext.winograd.input_transform (LinalgExt::WinogradInputTransformOp)

Winograd Input Transform operator

Syntax:

operation ::= `iree_linalg_ext.winograd.input_transform` attr-dict
              `output_tile_size` `(` $output_tile_size `)`
              `kernel_size` `(` $kernel_size `)`
              `image_dimensions` `(` $image_dimensions `)`
              `ins` `(` $inputs `:` type($inputs) `)`
              `outs` `(` $outputs `:` type($outputs) `)`
              (`->` type($result)^)?

This operator is part of the first step in converting a convolution to its Winograd equivalent. Given a tile of an input image (I), this operator computes matmul(transpose(B), matmul(I, B)). The input tile is assumed to be square with each side of size m + r - 1, where the convolutional kernel is m x m and the output tile size is r x r. B is a constant 2-d square matrix of the same shape as the input tile I. The input to the operator is an image of shape (N, H, W, C) or (N, C, H, W) and the output is a tensor of shape (m + r - 1, m + r - 1, N, H', W', C) where H' = ceil((H - m + 1)/r) and W' = ceil((W - m + 1)/r). The result of this operator is first collapsed and then fed to a batch matmul op.

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| output_tile_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| kernel_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| image_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of shaped of any type values |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| result | variadic of ranked tensor of any type values |

iree_linalg_ext.winograd.output_transform (LinalgExt::WinogradOutputTransformOp)

Winograd Output Transform operator

Syntax:

operation ::= `iree_linalg_ext.winograd.output_transform` attr-dict
              `output_tile_size` `(` $output_tile_size `)`
              `kernel_size` `(` $kernel_size `)`
              `image_dimensions` `(` $image_dimensions `)`
              `ins` `(` $inputs `:` type($inputs) `)`
              `outs` `(` $outputs `:` type($outputs) `)`
              (`->` type($result)^)?

This operator is the last transform in converting a convolution to its Winograd equivalent. After convolution in the Winograd domain (which turns into an elementwise product for a single channel and batch matrix multiplication for many channels), this operator converts the output back into the original domain. Given a tile of the output (O) in the Winograd domain, this operator computes matmul(transpose(A), matmul(O, A)). The output tile is square with each side of size m + r - 1, where the convolutional kernel is m x m and the output tile size is r x r. A is a constant 2-d matrix of shape (m + r - 1) x r. The input to the operator is a tensor of shape (m + r - 1, m + r - 1, N, H', W', C) and the output is a tensor of shape (N, H, W, C) or (N, C, H, W) where H = r H' and W = r W'. This operator is followed by a tensor.extract_slice which extracts only the non-padded part of the output.
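
To show how the three transforms fit together, the Python sketch below runs a single-channel Winograd convolution on one tile with m = 3 (kernel size) and r = 2 (output tile size), so the input tile is 4 x 4. The constant matrices are the commonly published F(2x2, 3x3) ones and are an assumption here, not values taken from IREE; the filter transform is taken in its standard form G F transpose(G), and all function names are hypothetical.

```python
# Constant matrices for m = 3, r = 2 (assumed, from the standard F(2x2, 3x3) formulation).
B_T = [[1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0], [0, 1, 0, -1]]   # transpose(B), 4 x 4
G = [[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]]        # (m + r - 1) x m
A_T = [[1, 1, 1, 0], [0, 1, -1, -1]]                                 # transpose(A), r x 4


def transpose(a):
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def filter_transform(f):
    return matmul(G, matmul(f, transpose(G)))          # G F G^T, 4 x 4


def input_transform(tile):
    return matmul(B_T, matmul(tile, transpose(B_T)))   # B^T I B, 4 x 4


def output_transform(o):
    return matmul(A_T, matmul(o, transpose(A_T)))      # A^T O A, 2 x 2


def winograd_tile(tile, f):
    """Single-channel case: elementwise product in the Winograd domain."""
    u, v = filter_transform(f), input_transform(tile)
    product = [[u[i][j] * v[i][j] for j in range(4)] for i in range(4)]
    return output_transform(product)


def direct_tile(tile, f):
    """Reference: direct 3 x 3 correlation producing the same 2 x 2 output tile."""
    return [[sum(tile[i + p][j + q] * f[p][q] for p in range(3) for q in range(3))
             for j in range(2)] for i in range(2)]
```

For any 4 x 4 tile and 3 x 3 filter, winograd_tile and direct_tile should agree up to floating-point rounding, which is a convenient sanity check on the matrix shapes quoted in the operator descriptions.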

Traits: AttrSizedOperandSegments, SingleBlockImplicitTerminator<::mlir::iree_compiler::IREE::LinalgExt::YieldOp>, SingleBlock

Interfaces: DestinationStyleOpInterface, LinalgExtInterface, MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface, TilingInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| output_tile_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| kernel_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| image_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of shaped of any type values |
| outputs | variadic of shaped of any type values |

Results:

| Result | Description |
| --- | --- |
| result | variadic of ranked tensor of any type values |

Attributes

EncodingAttr

Information to decide how to data-tile a tensor

Syntax:

#iree_linalg_ext.encoding<
  EncodingRoleAttr,   # role
  ArrayAttr,   # element_types
  TypeAttr,   # original_type
  IntegerAttr,   # matmul_narrow_M
  IntegerAttr,   # matmul_narrow_N
  ArrayAttr,   # user_indexing_maps
  DenseArrayAttr   # round_dims_to
>

This attribute describes the change in the layout for a given tensor to execute subsequent operations on the tiled layout. The encoding serves as a way to represent the change in the way the data is laid out in memory without changing the logical rank/extent of the tensor itself. When required, the encoding can be used to explicitly manifest the layout change through operations like pack/unpack.

Parameters:

| Parameter | C++ type | Description |
| --- | --- | --- |
| role | EncodingRoleAttr | role of this tensor as an operand |
| element_types | ArrayAttr | element types of the user's operands |
| original_type | TypeAttr | type of the original tensor before padding |
| matmul_narrow_M | IntegerAttr | optional M narrow dimension size (only for contraction op user_indexing_maps) |
| matmul_narrow_N | IntegerAttr | optional N narrow dimension size (only for contraction op user_indexing_maps) |
| user_indexing_maps | ArrayAttr | Indexing maps of the operation using this tensor |
| round_dims_to | DenseArrayAttr | Values for padding M,N,K dimensions |

EncodingRoleAttr

Describes the role of the tensor as an operand or a result of an operation.

Syntax:

#iree_linalg_ext.role<
  ::mlir::iree_compiler::IREE::LinalgExt::EncodingRole   # value
>

Enum cases:
* LHS (LHS)
* RHS (RHS)
* RESULT (RESULT)

Parameters:

| Parameter | C++ type | Description |
| --- | --- | --- |
| value | ::mlir::iree_compiler::IREE::LinalgExt::EncodingRole | an enum of type EncodingRole |