'iree_gpu' Dialect

A dialect representing attributes used by GPU-focused IREE code generation.

This dialect provides operations and attributes to aid in code generation for GPU targets. The functionality in this dialect can be hardware-specific, but is intended to be independent of the lowering target. Late lowerings to SPIR-V/LLVM are handled separately.

Operations

iree_gpu.multi_mma (GPU::MultiMmaOp)

Models a contraction of multiple mma operations

Syntax:

operation ::= `iree_gpu.multi_mma` $lhs `,` $rhs `,` $acc attr-dict
              `:` type($lhs) `,` type($rhs) `into` type($acc)

Computes the sum of inner MMA operations along a set of outer dimensions. Logically, it closely matches a vector.contraction operation; however, the combiner type is a specific intrinsic rather than a generic combiner type.

Similar to vector.contraction, an iterator type attribute list must be specified, where each element of the list represents an iterator over one of the outer dimensions. Iteration of inner dimensions is defined solely by the intrinsic and may be opaque.

An indexing map attribute list must be specified with an entry for the lhs, rhs, and acc arguments. Each indexing map specifies a mapping from the outer loop iterators in the iterator type list to the outer dimensions of the corresponding operand.

The combiner type is defined by the intrinsic.
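To illustrate how the indexing maps tie the outer iterators to operand dimensions, the Python sketch below models each map as a tuple of iterator names (so `("i", "k")` stands for `affine_map<(i, j, k) -> (i, k)>`) and infers the extent of every outer iterator from the operand shapes. It assumes, as in the examples on this page, that each operand's trailing dimensions are the intrinsic's inner dimensions. The function `infer_outer_bounds` is hypothetical and is not part of IREE.

```python
def infer_outer_bounds(maps, shapes, num_inner):
    """Infer a size for each outer iterator from the operands' outer dims.

    maps:      one tuple of iterator names per operand (lhs, rhs, acc).
    shapes:    full operand shapes, outer dims first, inner dims last.
    num_inner: number of trailing inner (intrinsic) dims per operand.
    """
    bounds = {}
    for operand_map, shape, n_inner in zip(maps, shapes, num_inner):
        outer = shape[: len(shape) - n_inner]
        if len(outer) != len(operand_map):
            raise ValueError(f"map {operand_map} does not match outer shape {outer}")
        for iterator, size in zip(operand_map, outer):
            # Every operand that uses an iterator must agree on its extent.
            if bounds.setdefault(iterator, size) != size:
                raise ValueError(f"inconsistent extent for iterator {iterator!r}")
    return bounds


maps = [("i", "k"), ("k", "j"), ("i", "j")]
# vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>
shapes = [(2, 3, 4), (3, 5, 4), (2, 5, 4)]
print(infer_outer_bounds(maps, shapes, [1, 1, 1]))  # {'i': 2, 'k': 3, 'j': 5}
```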

Example:

#contraction_accesses = [
 affine_map<(i, j, k) -> (i, k)>,
 affine_map<(i, j, k) -> (k, j)>,
 affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["parallel", "parallel", "reduction"],
  kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
  : vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>

// Takes tensors as well, however the inner dimensions must always be
// static.
%7 = iree_gpu.multi_mma %4, %5, %6 #contraction_trait
  : tensor<?x?x4xf16>, tensor<?x?x4xf16> into tensor<?x?x4xf32>

The tensor example above can logically be lowered directly to loops like this (ignoring the tensor-to-vector type conversions needed for the MFMA).

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%outer_m = tensor.dim %6, %c0 : tensor<?x?x4xf32>
%outer_n = tensor.dim %6, %c1 : tensor<?x?x4xf32>
%outer_k = tensor.dim %4, %c1 : tensor<?x?x4xf16>
%7 = scf.for %i = %c0 to %outer_m step %c1 iter_args(%arg0 = %6) -> (tensor<?x?x4xf32>) {
  %8 = scf.for %j = %c0 to %outer_n step %c1 iter_args(%arg1 = %arg0) -> (tensor<?x?x4xf32>) {
    %9 = scf.for %k = %c0 to %outer_k step %c1 iter_args(%arg2 = %arg1) -> (tensor<?x?x4xf32>) {
      %lhs = tensor.extract_slice %4 [%i, %k, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf16> to tensor<4xf16>
      %rhs = tensor.extract_slice %5 [%k, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf16> to tensor<4xf16>
      %acc = tensor.extract_slice %arg2 [%i, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf32> to tensor<4xf32>
      // Logical only: the real amdgpu.mfma op operates on vectors.
      %res = amdgpu.mfma %lhs, %rhs, %acc : tensor<4xf32>
      %ret = tensor.insert_slice %res into %arg2 [%i, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<?x?x4xf32>
      scf.yield %ret : tensor<?x?x4xf32>
    }
    scf.yield %9 : tensor<?x?x4xf32>
  }
  scf.yield %8 : tensor<?x?x4xf32>
}
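The loop nest above can be mirrored by a small pure-Python reference, shown here only as a sketch of the semantics. Operands are nested lists indexed by outer coordinates, and each element is an inner fragment. As a stand-in for the hardware intrinsic (an assumption: real MFMA fragments are distributed across lanes), the hypothetical `toy_mma` multiplies small dense blocks, so the blocked result can be checked against an ordinary matrix multiply.

```python
def toy_mma(a, b, c):
    """Stand-in for the intrinsic: returns c + a x b on small dense blocks."""
    return [[c[r][q] + sum(a[r][t] * b[t][q] for t in range(len(b)))
             for q in range(len(b[0]))] for r in range(len(a))]


def multi_mma_reference(lhs, rhs, acc, mma):
    """Outer (i, j, k) loops for indexing maps (i, k), (k, j), (i, j)."""
    outer_m, outer_k, outer_n = len(lhs), len(lhs[0]), len(rhs[0])
    result = [row[:] for row in acc]  # value semantics: acc is not mutated
    for i in range(outer_m):
        for j in range(outer_n):
            for k in range(outer_k):  # k is the "reduction" iterator
                result[i][j] = mma(lhs[i][k], rhs[k][j], result[i][j])
    return result


def to_blocks(matrix, bs):
    """Split a dense matrix into an outer grid of bs x bs inner blocks."""
    return [[[row[q * bs:(q + 1) * bs] for row in matrix[p * bs:(p + 1) * bs]]
             for q in range(len(matrix[0]) // bs)]
            for p in range(len(matrix) // bs)]


def from_blocks(blocks):
    """Inverse of to_blocks: stitch the grid of blocks back into rows."""
    return [sum((blk[r] for blk in block_row), [])
            for block_row in blocks for r in range(len(block_row[0]))]


A = [[r * 6 + c for c in range(6)] for r in range(4)]        # 4x6 -> outer 2x3
B = [[(r + 2 * c) % 5 for c in range(4)] for r in range(6)]  # 6x4 -> outer 3x2
C = [[0] * 4 for _ in range(4)]
out = multi_mma_reference(to_blocks(A, 2), to_blocks(B, 2), to_blocks(C, 2), toy_mma)
print(from_blocks(out) == toy_mma(A, B, C))  # True
```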

Alternatively, when operating on vectors, it can be unrolled so that each op computes a single intrinsic.

#contraction_accesses = [
 affine_map<() -> ()>,
 affine_map<() -> ()>,
 affine_map<() -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = [],
  kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
  : vector<4xf16>, vector<4xf16> into vector<4xf32>

This operation can represent an intrinsic both in subgroup/warp and distributed (thread) abstractions through the intrinsic attribute interface. It does so semi-opaquely by including optional permutations of each MMA fragment with respect to the "canonical" MNK row major matrix multiply.

Since the canonical dimensionality of the inner dimensions is somewhat intrinsic-specific, verification of this op only requires that the element counts of the inner dimensions match the intrinsic.
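Because only element counts are compared, a check along the lines of the Python sketch below is enough; it is hypothetical and is not IREE's verifier. The counts for MFMA_F16_16x16x16_F32 follow from its shape: at subgroup level each 16x16 fragment has 256 elements, and distributed over a 64-lane wavefront each lane holds 256 / 64 = 4, which matches the trailing x4 inner dimension in the distributed examples above.

```python
from math import prod

# Per-operand (lhs, rhs, acc) element counts for MFMA_F16_16x16x16_F32,
# derived from its 16x16x16 shape and a 64-lane wavefront.
SUBGROUP_COUNTS = (16 * 16, 16 * 16, 16 * 16)
THREAD_COUNTS = tuple(n // 64 for n in SUBGROUP_COUNTS)  # (4, 4, 4)


def inner_counts_match(inner_shapes, expected):
    """Return True if each operand's inner dims hold the expected element count.

    Only the product of the dims is compared, so the layout stays opaque:
    [16, 16] and [256] are both accepted for a 256-element fragment.
    """
    if len(inner_shapes) != len(expected):
        return False
    return all(prod(shape) == n for shape, n in zip(inner_shapes, expected))


print(inner_counts_match([[4], [4], [4]], THREAD_COUNTS))                   # True
print(inner_counts_match([[16, 16], [16, 16], [16, 16]], SUBGROUP_COUNTS))  # True
print(inner_counts_match([[2, 2], [4], [4]], THREAD_COUNTS))                # True
print(inner_counts_match([[8], [4], [4]], THREAD_COUNTS))                   # False
```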

For example, an MMT product (with the RHS fragment transposed) of inner dimensions with subgroup (warp) semantics can be represented as follows. Permutations are only allowed for ops with subgroup semantics and must be resolved before distribution.

#contraction_accesses = [
 affine_map<(i, j, k) -> (i, k)>,
 affine_map<(i, j, k) -> (k, j)>,
 affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["parallel", "parallel", "reduction"],
  kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
  rhs_permutation = [1, 0]
}
%7 = iree_gpu.multi_mma %4, %5, %6 #contraction_trait
  : tensor<?x?x16x16xf16>, tensor<?x?x16x16xf16> into tensor<?x?x16x16xf32>
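To make the permutation concrete: with `rhs_permutation = [1, 0]` the inner RHS fragment is held as N x K, and permuting its two dimensions recovers the canonical K x N operand of the row-major MNK multiply. The Python sketch below uses 2x2 fragments for brevity. It assumes that result dimension `d` takes source dimension `perm[d]`; for `[1, 0]` this and the inverse convention coincide, since a 2-D transpose is its own inverse. The helpers are hypothetical.

```python
def permute_2d(fragment, perm):
    """Permute the two dims of a 2-D fragment; result dim d is source dim perm[d]."""
    if list(perm) == [0, 1]:
        return [row[:] for row in fragment]
    if list(perm) == [1, 0]:
        return [list(col) for col in zip(*fragment)]
    raise ValueError(f"not a 2-D permutation: {perm}")


def canonical_mma(a, b):
    """Canonical row-major MNK product: a is M x K, b is K x N."""
    return [[sum(a[m][k] * b[k][n] for k in range(len(b))) for n in range(len(b[0]))]
            for m in range(len(a))]


a = [[1, 2], [3, 4]]    # M x K
b_t = [[5, 6], [7, 8]]  # N x K, i.e. the RHS as stored when rhs_permutation = [1, 0]
print(canonical_mma(a, permute_2d(b_t, [1, 0])))  # [[17, 23], [39, 53]]
```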

Motivation, Design Choices, and Pitfalls

The idea behind this operation is to decouple the layout setting/tiling required to target certain intrinsics from the lowering to them. Tiling of this sort typically happens on tensor operands, whereas the target intrinsics operate on vectors, so this operation is used to bridge the gap. Sharing one operation is intended to ease the lowering process and to allow different transformations at different stages of the pipeline without essentially cloning this op.

The choice to let the inner dimensions required to compute the intrinsic be implicit based on the indexing maps was made to make this operation easier to generate and to avoid the need for type conversion ops. However, this comes at the expense of making the operation harder to verify. It is also implicitly linked to a lane-level parent scf.forall operation.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), VectorUnrollOpInterface

Effects: MemoryEffects::Effect{}

Attributes:
Attribute         MLIR Type                     Description
indexing_maps     ::mlir::ArrayAttr             array attribute
iterator_types    ::mlir::ArrayAttr             Iterator type should be an enum.
kind              IREE::GPU::MmaInterfaceAttr   buffer-like constant attribute values
lhs_permutation   ::mlir::DenseI64ArrayAttr     i64 dense array attribute
rhs_permutation   ::mlir::DenseI64ArrayAttr     i64 dense array attribute
acc_permutation   ::mlir::DenseI64ArrayAttr     i64 dense array attribute
Operands:
Operand   Description
lhs       ranked tensor or vector of any type values
rhs       ranked tensor or vector of any type values
acc       ranked tensor or vector of any type values
Results:
Result    Description
result    ranked tensor or vector of any type values

iree_gpu.shuffle_tensor (GPU::ShuffleTensorOp)

Shuffles a private tensor across a shared allocation

Syntax:

operation ::= `iree_gpu.shuffle_tensor` $source ``
              custom<DynamicIndexList>($offsets, $static_offsets)
              custom<DynamicIndexList>($sizes, $static_sizes)
              custom<DynamicIndexList>($strides, $static_strides)
              `to` $dest $region attr-dict
              `:` type($source) `->` type($dest) `->` type($result)

This op is designed to represent a shuffle of private tensor data collectively held across a set of workers. This operation naturally arises when combining the regions of producer-consumer scf.forall operations that share a mapping type and worker count.

For example, consider the following pair of parallel loops.

  %0 = scf.forall (%idy, %idx) in (2, 32) shared_outs(%init = %empty) -> (tensor<4x128xf32>) {
    %in = ...
    %2 = affine.apply affine_map<(d0) -> (d0 * 2)> (%idy)
    %3 = affine.apply affine_map<(d0) -> (d0 * 4)> (%idx)
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %in into %init[%2, %3] [2, 4] [1, 1]
        : tensor<2x4xf32> into tensor<4x128xf32>
    }
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
  %1 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<128x128xf32>) {
    %4 = affine.apply affine_map<(d0) -> (d0 * 16)> (%idx)
    %extracted_slice = tensor.extract_slice %0[0, %4] [4, 16] [1, 1]
      : tensor<4x128xf32> to tensor<4x16xf32>
    ...
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}

Because these loops share the same worker type and total worker count, their bodies can be merged, with a barrier and a shuffle placed at the current boundary between the two loops.

  %0 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<128x128xf32>) {
    %linear = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)> (%idy, %idx)
    %ids:2 = affine.delinearize_index %linear into (2, 32) : index, index
    %in = ...
    %2 = affine.apply affine_map<(d0) -> (d0 * 2)> (%ids#0)
    %3 = affine.apply affine_map<(d0) -> (d0 * 4)> (%ids#1)
    %4 = affine.apply affine_map<(d0) -> (d0 * 16)> (%idx)
    %slice = iree_gpu.shuffle_tensor %in[%2, %3] [2, 4] [1, 1] to %empty {
    ^bb0(%intermediate: tensor<4x128xf32>):
      %read = tensor.extract_slice %intermediate[0, %4] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
      iree_gpu.yield %read : tensor<4x16xf32>
    } : tensor<2x4xf32> -> tensor<4x128xf32> -> tensor<4x16xf32>
    ...
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
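The fused body recovers the producer's (2, 32) worker coordinates from the consumer's (8, 8) ones by linearizing and then delinearizing. The Python check below is a sketch, not part of IREE: it confirms that, because the worker counts match, this mapping hits every producer worker exactly once, and that the producer's 2x4 writes tile the 4x128 intermediate with no gaps and no conflicting writes.

```python
from itertools import product


def delinearize(linear, basis):
    """Mixed-radix decomposition, outermost dim first (like affine.delinearize_index)."""
    coords = []
    for size in reversed(basis):
        linear, rem = divmod(linear, size)
        coords.append(rem)
    return tuple(reversed(coords))


# Map every consumer worker (idy, idx) in the (8, 8) grid to producer coords.
producer_ids = {delinearize(idy * 8 + idx, (2, 32))
                for idy, idx in product(range(8), range(8))}
assert producer_ids == set(product(range(2), range(32)))  # a bijection

# Each producer worker writes a 2x4 slice at [2 * id0, 4 * id1]; count writes per element.
hits = [[0] * 128 for _ in range(4)]
for id0, id1 in producer_ids:
    for r, c in product(range(2), range(4)):
        hits[2 * id0 + r][4 * id1 + c] += 1
print(all(h == 1 for row in hits for h in row))  # True
```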

A shuffle can be lowered to a shared allocation with a write of the source slice, a barrier, the inlined body of the shuffle op (the read), and then a second barrier to synchronize all workers on the result of the read. Note that any conflicting writes to the intermediate are undefined behavior. Also, because the lowering relies on barriers, lowering the enclosing scf.forall to serial loops is invalid. In other words, the lowering must provide the number of workers requested by the loop.
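The lowering just described can be simulated with Python threads, purely as a sketch of the semantics and not of how IREE generates code. There is one thread per worker, a shared list stands in for the allocation, and `threading.Barrier` stands in for the workgroup barriers. The shapes follow the fused example above. If the 64 workers were run serially, the first `barrier.wait()` would never return, which is one way to see why a serial lowering of the enclosing loop is invalid.

```python
import threading
from itertools import product

ROWS, COLS, WORKERS = 4, 128, 64
shared = [[None] * COLS for _ in range(ROWS)]  # stands in for the shared allocation
barrier = threading.Barrier(WORKERS)
results = {}


def worker(idy, idx):
    id0, id1 = divmod(idy * 8 + idx, 32)       # producer coords in the (2, 32) grid
    # Private 2x4 source slice; each value records its destination coordinates.
    source = [[(2 * id0 + r, 4 * id1 + c) for c in range(4)] for r in range(2)]
    # 1. Write the source slice into the intermediate.
    for r, c in product(range(2), range(4)):
        shared[2 * id0 + r][4 * id1 + c] = source[r][c]
    barrier.wait()                             # 2. every write is now visible
    # 3. Inlined shuffle body: read a 4x16 slice at [0, 16 * idx].
    results[(idy, idx)] = [row[16 * idx:16 * idx + 16] for row in shared]
    barrier.wait()                             # 4. every read is done before reuse


threads = [threading.Thread(target=worker, args=ids) for ids in product(range(8), range(8))]
for t in threads:
    t.start()
for t in threads:
    t.join()

expected = {(idy, idx): [[(r, 16 * idx + c) for c in range(16)] for r in range(4)]
            for idy, idx in product(range(8), range(8))}
print(results == expected)  # True
```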

This op takes an input |source| tensor to represent the slice held by this worker before the shuffle, an intermediate tensor |dest| that all workers insert into, and performs a synchronized read from that intermediate tensor.

It is undefined behavior if the source tensor is out of bounds of the intermediate allocation.

Motivation and Intended Use Cases:

This op is primarily generated when fusing parallel loops with tensor results. It helps make lowerings more progressive and flexible:
- Rather than lowering straight to vector ops for the reads/writes of the shuffle, it allows separating the vectorization of the shared memory accesses from earlier tiling steps.
- Lowering directly to an alloc + reads and writes breaks the dependency chain, potentially making transformations like barrier placement and pipelining more difficult.
- It allows the option of non-vector based lowering paths.

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments, SingleBlockImplicitTerminator<mlir::iree_compiler::IREE::GPU::YieldOp>, SingleBlock

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:
Attribute         MLIR Type                   Description
static_offsets    ::mlir::DenseI64ArrayAttr   i64 dense array attribute
static_sizes      ::mlir::DenseI64ArrayAttr   i64 dense array attribute
static_strides    ::mlir::DenseI64ArrayAttr   i64 dense array attribute
Operands:
Operand   Description
source    ranked tensor of any type values
offsets   variadic of index
sizes     variadic of index
strides   variadic of index
dest      ranked tensor of any type values
Results:
Result    Description
result    ranked tensor or vector of any type values

iree_gpu.value_barrier (GPU::ValueBarrierOp)

Synchronizes workers on a value-semantic tensor or vector

Syntax:

operation ::= `iree_gpu.value_barrier` $input attr-dict `:` type($result)

This operation acts as a barrier on a value semantic SSA value (tensor or vector). It takes a single operand and produces a value equivalent to the input. This does not have copy and/or data movement semantics and simply represents a barrier on all writes in the tensor case, and a barrier until all threads acquire the input vector in the vector case.

This operation is a no-op when not present in a parallel context. This operation is pure as it only requires synchronization for the value it produces.
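The tensor case can be pictured with the Python sketch below, which models the op as a hypothetical `value_barrier` function that waits on a barrier and then returns its input unchanged. No copy is made; the only effect is ordering, so after the call each worker can safely read elements written by other workers. The worker setup is illustrative.

```python
import threading

WORKERS = 4
shared = [0] * WORKERS
barrier = threading.Barrier(WORKERS)
seen = [None] * WORKERS


def value_barrier(value):
    """Block until every worker arrives, then return the same object (no copy)."""
    barrier.wait()
    return value


def worker(tid):
    shared[tid] = tid + 1                    # each worker writes its own element
    synced = value_barrier(shared)           # barrier on all writes to the value
    assert synced is shared                  # equivalent to the input, not a copy
    seen[tid] = synced[(tid + 1) % WORKERS]  # safe to read a neighbor's write now


threads = [threading.Thread(target=worker, args=(t,)) for t in range(WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(seen)  # [2, 3, 4, 1]
```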

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:
Operand   Description
input     ranked tensor or vector of any type values
Results:
Result    Description
result    ranked tensor or vector of any type values

iree_gpu.yield (GPU::YieldOp)

Yield a value from a region

Syntax:

operation ::= `iree_gpu.yield` $value attr-dict `:` type($value)

This operation is used to yield a single value from within a region.

Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::iree_compiler::IREE::GPU::ShuffleTensorOp>, ReturnLike, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface

Effects: MemoryEffects::Effect{}

Operands:
Operand   Description
value     any type

Attributes

ComputeBitwidthsAttr

Supported bitwidths for compute

Syntax:

#iree_gpu.compute_bitwidths<
  ::mlir::iree_compiler::IREE::GPU::ComputeBitwidths   # value
>

Enum cases:
* fp64 (FP64)
* fp32 (FP32)
* fp16 (FP16)
* int64 (Int64)
* int32 (Int32)
* int16 (Int16)
* int8 (Int8)

Parameters:
Parameter   C++ type                                             Description
value       ::mlir::iree_compiler::IREE::GPU::ComputeBitwidths   an enum of type ComputeBitwidths

DerivedThreadConfigAttr

Drive lowering of an operation by deriving thread distribution when needed.

Syntax: #iree_gpu.derived_thread_config

Lowering config for a single thread tiling level that is inferred after previous (often reduction) levels of tile + fuse. This is intended for fused operations where it is much easier to compute the tile sizes to use after previous levels of tile + fuse, rather than trying to pre-propagate tiling configs.

DotProductOpsAttr

Supported dot product ops

Syntax:

#iree_gpu.dotproduct_ops<
  ::mlir::iree_compiler::IREE::GPU::DotProductOps   # value
>

Enum cases:
* none (None)
* dp4xi8toi32 (DP4xI8ToI32)

Parameters:
Parameter   C++ type                                          Description
value       ::mlir::iree_compiler::IREE::GPU::DotProductOps   an enum of type DotProductOps

IteratorTypeAttr

Iterator type

Syntax:

#iree_gpu.iterator_type<
  ::mlir::utils::IteratorType   # value
>

Enum cases:
* parallel (parallel)
* reduction (reduction)

Parameters:
Parameter   C++ type                      Description
value       ::mlir::utils::IteratorType   an enum of type IteratorType

LaneIdAttr

Syntax:

#iree_gpu.lane_id<
  int64_t   # dim
>

An attribute for mapping scf.forall ops to subgroup lanes.

Parameters:
Parameter   C++ type   Description
dim         int64_t

LoweringConfigAttr

Drive lowering of an operation for GPU compilation.

Syntax:

#iree_gpu.lowering_config<
  DictionaryAttr   # attributes
>

GPU specific implementation of a lowering config. This carries just a dictionary attribute to store any relevant fields. This is the simplest form of a lowering config, offering flexibility at the cost of structure.

Parameters:
Parameter    C++ type         Description
attributes   DictionaryAttr   The configured fields, including tiling levels

MMAAttr

Attribute describing a particular shape of matrix-multiply and accumulate instruction. Abstractly, all attributes of this type represent the following unit of arithmetic for matrices A, B, and C.

  C += A x B

Where the shape of matrix A is [m, k], B is [k, n], and C is [m, n]. This intentionally leaves the layout information abstract and uses interface methods to materialize layout information only when needed. The shape of the mma intrinsic is stored explicitly here as that information is queried frequently.

The element types for this particular mma intrinsic are |aType|, |bType|, and |cType| for matrices A, B, and C respectively.

This MMA variant describes configurations for MMA ops. The |intrinsic| field specifies which particular MMA intrinsic this refers to, with each intrinsic implying a specific MNK shape and operand types. The intrinsic enum name describes these fields as

<kind>_<InputType>_MxNxK_<OutputType>

where the element type of both the A and B matrices is InputType.

Parameters:
Parameter   C++ type           Description
intrinsic   MMAIntrinsicAttr
mSize       int64_t
nSize       int64_t
kSize       int64_t
aType       ::mlir::Type
bType       ::mlir::Type
cType       ::mlir::Type

MMAIntrinsicAttr

Descriptor for different MMA intrinsics

Syntax:

#iree_gpu.mma_intrinsic<
  ::mlir::iree_compiler::IREE::GPU::MMAIntrinsic   # value
>

Enum cases:
* MFMA_F16_16x16x16_F32 (MFMA_F16_16x16x16_F32)
* MFMA_F16_32x32x8_F32 (MFMA_F16_32x32x8_F32)
* WMMA_F16_16x16x16_F32 (WMMA_F16_16x16x16_F32)
* WMMA_F16_16x16x16_F16 (WMMA_F16_16x16x16_F16)

Parameters:
Parameter   C++ type                                         Description
value       ::mlir::iree_compiler::IREE::GPU::MMAIntrinsic   an enum of type MMAIntrinsic
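The enum case names listed above all follow a KIND_INPUTTYPE_MxNxK_OUTPUTTYPE pattern (for example MFMA_F16_32x32x8_F32: kind MFMA, F16 inputs, M=32, N=32, K=8, F32 accumulator). The Python sketch below decodes a name into those fields; `parse_intrinsic` and `IntrinsicName` are hypothetical helpers, not part of IREE.

```python
import re
from typing import NamedTuple


class IntrinsicName(NamedTuple):
    kind: str
    input_type: str   # element type of both A and B
    m: int
    n: int
    k: int
    output_type: str  # element type of C


_PATTERN = re.compile(r"([A-Z]+)_([A-Z0-9]+)_(\d+)x(\d+)x(\d+)_([A-Z0-9]+)")


def parse_intrinsic(name):
    """Split an intrinsic enum name into kind, types, and MNK shape."""
    match = _PATTERN.fullmatch(name)
    if match is None:
        raise ValueError(f"unrecognized intrinsic name: {name}")
    kind, in_ty, m, n, k, out_ty = match.groups()
    return IntrinsicName(kind, in_ty, int(m), int(n), int(k), out_ty)


for name in ["MFMA_F16_16x16x16_F32", "MFMA_F16_32x32x8_F32",
             "WMMA_F16_16x16x16_F32", "WMMA_F16_16x16x16_F16"]:
    print(parse_intrinsic(name))
```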

MMAOpsArrayAttr

Syntax:

#iree_gpu.mma_ops<
  ::llvm::ArrayRef<MMAAttr>   # value
>

Parameters:
Parameter   C++ type                    Description
value       ::llvm::ArrayRef<MMAAttr>

MMAScheduleAttr

Syntax:

#iree_gpu.mma_schedule<
  ::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr,   # intrinsic
  int64_t,   # subgroup_m_count
  int64_t   # subgroup_n_count
>

A schedule of MMA intrinsic instructions and various levels of tile sizes to solve a specific contraction problem.

Parameters:
Parameter          C++ type                                             Description
intrinsic          ::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr
subgroup_m_count   int64_t
subgroup_n_count   int64_t

StorageBitwidthsAttr

Supported bitwidths for storage

Syntax:

#iree_gpu.storage_bitwidths<
  ::mlir::iree_compiler::IREE::GPU::StorageBitwidths   # value
>

Enum cases:
* b64 (B64)
* b32 (B32)
* b16 (B16)
* b8 (B8)

Parameters:
Parameter   C++ type                                             Description
value       ::mlir::iree_compiler::IREE::GPU::StorageBitwidths   an enum of type StorageBitwidths

SubgroupOpsAttr

Supported subgroup ops

Syntax:

#iree_gpu.subgroup_ops<
  ::mlir::iree_compiler::IREE::GPU::SubgroupOps   # value
>

Enum cases:
* none (None)
* shuffle (Shuffle)
* arithmetic (Arithmetic)

Parameters:
Parameter   C++ type                                        Description
value       ::mlir::iree_compiler::IREE::GPU::SubgroupOps   an enum of type SubgroupOps

TargetAttr

Full GPU target attribute

Syntax:

#iree_gpu.target<
  ::llvm::StringRef,   # arch
  ::llvm::StringRef,   # features
  TargetWgpAttr,   # wgp
  TargetChipAttr   # chip
>

This attribute describes a full GPU target. It contains a few fields:
* The canonical target architecture for compilation, e.g., sm_80 for CUDA, gfx942 for HIP
* A TargetWgpAttr describing the GPU features and limits in a single GPU workgroup processor (WGP), that is, an AMD compute unit or an NVIDIA streaming multiprocessor
* An optional TargetChipAttr describing GPU features for the final chip or product, e.g., WGP count

Parameters:
Parameter   C++ type            Description
arch        ::llvm::StringRef   target architecture
features    ::llvm::StringRef   target features
wgp         TargetWgpAttr
chip        TargetChipAttr

TargetChipAttr

Chip level target description

Syntax:

#iree_gpu.target_chip<
  uint32_t,   # wgp_count
  DictionaryAttr   # extra
>

This attribute contains hardware features/limits at the level of a single GPU chip. Here a GPU chip means the hardware functionality scope onto which the whole software compute grid is scheduled. A chip typically contains many AMD compute units or NVIDIA streaming multiprocessors; it is the final SKU.

Parameters:
Parameter   C++ type         Description
wgp_count   uint32_t
extra       DictionaryAttr

TargetWgpAttr

Workgroup processor level target description

Syntax:

#iree_gpu.target_wgp<
  ComputeBitwidthsAttr,   # compute
  StorageBitwidthsAttr,   # storage
  SubgroupOpsAttr,   # subgroup
  DotProductOpsAttr,   # dot
  MMAOpsArrayAttr,   # mma
  DenseI32ArrayAttr,   # subgroup_size_choices
  DenseI32ArrayAttr,   # max_workgroup_sizes
  uint32_t,   # max_thread_count_per_workgroup
  uint32_t,   # max_workgroup_memory_bytes
  DictionaryAttr   # extra
>

This attribute contains hardware features/limits at the level of a single GPU workgroup processor (WGP). Here a GPU workgroup processor means the basic hardware functionality unit onto which a software workgroup is scheduled; that is, a compute unit for AMD GPUs or a streaming multiprocessor for NVIDIA GPUs.

Parameters:
Parameter                        C++ type               Description
compute                          ComputeBitwidthsAttr
storage                          StorageBitwidthsAttr
subgroup                         SubgroupOpsAttr
dot                              DotProductOpsAttr
mma                              MMAOpsArrayAttr
subgroup_size_choices            DenseI32ArrayAttr
max_workgroup_sizes              DenseI32ArrayAttr
max_thread_count_per_workgroup   uint32_t
max_workgroup_memory_bytes       uint32_t
extra                            DictionaryAttr
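A typical consumer of the WGP limits checks whether a candidate workgroup shape and its shared-memory use fit the target. The Python sketch below mirrors the TargetAttr hierarchy (arch, features, wgp, optional chip) and a subset of the TargetWgpAttr parameters as dataclasses. The field names are copied from the tables above, but the classes and `workgroup_fits` are hypothetical, and the limit values are illustrative rather than taken from any specific GPU.

```python
from dataclasses import dataclass
from math import prod
from typing import Optional, Sequence


@dataclass(frozen=True)
class TargetWgp:
    """Subset of TargetWgpAttr's parameters relevant to workgroup sizing."""
    subgroup_size_choices: Sequence[int]
    max_workgroup_sizes: Sequence[int]  # per-dimension limits (x, y, z)
    max_thread_count_per_workgroup: int
    max_workgroup_memory_bytes: int


@dataclass(frozen=True)
class TargetChip:
    wgp_count: int


@dataclass(frozen=True)
class Target:
    arch: str
    features: str
    wgp: TargetWgp
    chip: Optional[TargetChip] = None  # the chip description is optional


def workgroup_fits(wgp, workgroup_size, shared_bytes=0):
    """Check a workgroup shape and its shared-memory use against the WGP limits."""
    if len(workgroup_size) > len(wgp.max_workgroup_sizes):
        return False
    per_dim_ok = all(s <= limit for s, limit in zip(workgroup_size, wgp.max_workgroup_sizes))
    return (per_dim_ok
            and prod(workgroup_size) <= wgp.max_thread_count_per_workgroup
            and shared_bytes <= wgp.max_workgroup_memory_bytes)


# Illustrative limits only.
target = Target("gfx942", "", TargetWgp([64], [1024, 1024, 1024], 1024, 64 * 1024))
print(workgroup_fits(target.wgp, [256, 1, 1], shared_bytes=32 * 1024))  # True
print(workgroup_fits(target.wgp, [64, 32, 1]))                          # False
```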