'iree_gpu' Dialect
A dialect representing attributes used by GPU-focused IREE code generation.
This dialect provides operations and attributes to aid in code generation for GPU targets. The functionality in this dialect can be hardware specific, but is intended to be independent of the lowering target. Late lowerings to SPIR-V/LLVM are handled separately.
Operations
iree_gpu.multi_mma (GPU::MultiMmaOp)
Models a contraction of multiple mma operations
Syntax:
operation ::= `iree_gpu.multi_mma` $lhs `,` $rhs `,` $acc attr-dict
`:` type($lhs) `,` type($rhs) `into` type($acc)
Computes the sum of inner MMA operations along a set of outer dimensions. Logically, this closely matches a vector.contract operation; however, the combiner type is a specific intrinsic rather than a generic combiner type.
Similar to vector.contract, an iterator type attribute list must be specified, where each element of the list represents an iterator over one of the outer dimensions. Iteration over the inner dimensions is defined solely by the intrinsic and may be opaque.
An indexing map attribute list must be specified with an entry for each of the lhs, rhs, and acc arguments. An indexing map attribute specifies a mapping from each outer loop iterator in the iterator type list to each outer dimension of each operand.
The combiner type is defined by the intrinsic.
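Because the indexing maps tie every outer dimension of lhs, rhs, and acc to an outer loop, the operands fully determine the outer trip counts. The Python sketch below is not IREE's verifier. It encodes each affine map as a tuple of loop indices, and the function name is hypothetical. It checks the shapes used in the example that follows (vector<2x3x4xf16>, vector<3x5x4xf16>, vector<2x5x4xf32>) and recovers i = 2, j = 5, k = 3. The trailing dimension of size 4 is left to the intrinsic.

```python
# Sketch: infer the outer iteration domain of a multi_mma-like op from the
# outer (non-intrinsic) dimensions of its operands. Each indexing map is
# modelled as a tuple of loop indices, e.g. (i, j, k) -> (i, k) is (0, 2).
from typing import Dict, List, Sequence, Tuple


def infer_outer_bounds(
    maps: Sequence[Tuple[int, ...]],
    outer_shapes: Sequence[Tuple[int, ...]],
    num_loops: int,
) -> List[int]:
    """Return the trip count of each outer loop; raise if operands disagree."""
    bounds: Dict[int, int] = {}
    for operand_map, shape in zip(maps, outer_shapes):
        if len(operand_map) != len(shape):
            raise ValueError("map result count must equal the outer rank")
        for loop_dim, size in zip(operand_map, shape):
            known = bounds.setdefault(loop_dim, size)
            if known != size:
                raise ValueError(f"loop {loop_dim}: {known} vs {size}")
    if sorted(bounds) != list(range(num_loops)):
        raise ValueError("every outer loop must appear in some map")
    return [bounds[d] for d in range(num_loops)]


# (i, j, k) -> (i, k), (k, j), (i, j), mirroring #contraction_accesses.
MAPS = [(0, 2), (2, 1), (0, 1)]
# Outer dims of vector<2x3x4xf16>, vector<3x5x4xf16>, vector<2x5x4xf32>;
# the trailing 4 belongs to the intrinsic and is left out.
print(infer_outer_bounds(MAPS, [(2, 3), (3, 5), (2, 5)], 3))  # [2, 5, 3]
```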
Example:
#contraction_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"],
kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
: vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>
// Takes tensors as well, however the inner dimensions must always be
// static.
%7 = iree_gpu.multi_mma %4, %5, %6 #contraction_trait
: tensor<?x?x4xf16>, tensor<?x?x4xf16> into tensor<?x?x4xf32>
Logically, the tensor example above can be lowered directly to loops like the following (ignoring the tensor-to-vector type conversions needed for the MFMA).
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%outer_m = tensor.dim %6, %c0 : tensor<?x?x4xf32>
%outer_n = tensor.dim %6, %c1 : tensor<?x?x4xf32>
%outer_k = tensor.dim %4, %c1 : tensor<?x?x4xf16>
%7 = scf.for %i = %c0 to %outer_m step %c1 iter_args(%arg0 = %6) -> (tensor<?x?x4xf32>) {
  %8 = scf.for %j = %c0 to %outer_n step %c1 iter_args(%arg1 = %arg0) -> (tensor<?x?x4xf32>) {
    %9 = scf.for %k = %c0 to %outer_k step %c1 iter_args(%arg2 = %arg1) -> (tensor<?x?x4xf32>) {
      %lhs = tensor.extract_slice %4 [%i, %k, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf16> to tensor<4xf16>
      %rhs = tensor.extract_slice %5 [%k, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf16> to tensor<4xf16>
      %acc = tensor.extract_slice %arg2 [%i, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf32> to tensor<4xf32>
      // Schematic: the real amdgpu.mfma operates on vectors (conversions elided).
      %res = amdgpu.mfma %lhs, %rhs, %acc : tensor<4xf32>
      %ret = tensor.insert_slice %res into %arg2 [%i, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<?x?x4xf32>
      scf.yield %ret : tensor<?x?x4xf32>
    }
    scf.yield %9 : tensor<?x?x4xf32>
  }
  scf.yield %8 : tensor<?x?x4xf32>
}
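In terms of mathematical results, this loop nest accumulates one inner MMA per (i, j, k) point. The Python sketch below models that directly, with a plain dense matmul standing in for the intrinsic. It deliberately ignores lane distribution and fragment layouts, and all names in it are hypothetical.

```python
# Reference-semantics sketch of multi_mma with maps (i, k), (k, j), (i, j).
# Each operand is a 2-D grid (outer dims) of 2-D blocks (inner dims); the
# inner "intrinsic" is modelled as a plain row-major C + A x B.
from typing import List

Block = List[List[float]]


def inner_mma(a: Block, b: Block, c: Block) -> Block:
    """Stand-in for the intrinsic: returns C + A x B for small dense blocks."""
    m, k, n = len(a), len(b), len(b[0])
    return [[c[r][col] + sum(a[r][x] * b[x][col] for x in range(k))
             for col in range(n)] for r in range(m)]


def multi_mma(lhs, rhs, acc):
    """lhs[i][k], rhs[k][j] and acc[i][j] are blocks; returns the new acc."""
    outer_m, outer_k, outer_n = len(lhs), len(rhs), len(rhs[0])
    result = [[acc[i][j] for j in range(outer_n)] for i in range(outer_m)]
    for i in range(outer_m):          # "parallel"
        for j in range(outer_n):      # "parallel"
            for k in range(outer_k):  # "reduction": accumulate into (i, j)
                result[i][j] = inner_mma(lhs[i][k], rhs[k][j], result[i][j])
    return result


lhs = [[[[2.0]], [[3.0]]]]    # outer 1x2 over (i, k), inner 1x1
rhs = [[[[4.0]]], [[[5.0]]]]  # outer 2x1 over (k, j), inner 1x1
acc = [[[[1.0]]]]             # outer 1x1 over (i, j), inner 1x1
print(multi_mma(lhs, rhs, acc))  # [[[[24.0]]]]
```

The reduction over k is what makes the order of the loop iterators matter only for k. The i and j iterations are independent, which matches their "parallel" iterator types.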
Alternatively, when operating on vectors, the op can be unrolled down to a single intrinsic with no outer dimensions:
#contraction_accesses = [
affine_map<() -> ()>,
affine_map<() -> ()>,
affine_map<() -> ()>
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = [],
kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
: vector<4xf16>, vector<4xf16> into vector<4xf32>
This operation can represent an intrinsic in both subgroup/warp and distributed (thread) abstractions through the intrinsic attribute interface. It does so semi-opaquely by including optional permutations of each MMA fragment with respect to the "canonical" MNK row-major matrix multiply.
Since the canonical dimensionality of the inner dimensions is somewhat intrinsic-specific, verification of this op requires only that the element counts of the inner dimensions match the intrinsic.
For example, an MMT product of inner dimensions with warp semantics can be represented as follows. Permutations are allowed only for ops with subgroup semantics and must be resolved before distribution.
#contraction_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"],
kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
rhs_permutation = [1, 0]
}
%7 = iree_gpu.multi_mma %4, %5, %6 #contraction_trait
: tensor<?x?x16x16xf16>, tensor<?x?x16x16xf16> into tensor<?x?x16x16xf32>
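Verification only compares element counts, so a permuted fragment such as the MMT right-hand side above is accepted as long as each inner shape multiplies out to the intrinsic's fragment size. The Python sketch below makes three assumptions:

- It assumes subgroup semantics, where the canonical fragments are M x K, K x N, and M x N.
- It uses MFMA_F16_32x32x8_F32 (M = N = 32, K = 8) so that the transpose is visible.
- Its function names are hypothetical.

For a 2-D [1, 0] permutation, the direction convention does not matter, because a transpose is its own inverse.

```python
# Sketch: element-count check for multi_mma inner dims (subgroup semantics
# assumed) plus application of an optional fragment permutation.
from math import prod
from typing import Sequence, Tuple


def counts_match(mnk: Tuple[int, int, int], lhs_inner: Sequence[int],
                 rhs_inner: Sequence[int], acc_inner: Sequence[int]) -> bool:
    """Only element counts are compared, not the exact inner shapes."""
    m, n, k = mnk
    return (prod(lhs_inner) == m * k and prod(rhs_inner) == k * n
            and prod(acc_inner) == m * n)


def is_permutation(perm: Sequence[int]) -> bool:
    return sorted(perm) == list(range(len(perm)))


def permute(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Dimension d of the result is dimension perm[d] of the input."""
    if not is_permutation(perm) or len(perm) != len(shape):
        raise ValueError("invalid permutation")
    return tuple(shape[p] for p in perm)


mnk = (32, 32, 8)                      # M, N, K of MFMA_F16_32x32x8_F32
rhs_stored = permute((8, 32), [1, 0])  # canonical K x N -> stored N x K
print(rhs_stored)                      # (32, 8)
print(counts_match(mnk, (32, 8), rhs_stored, (32, 32)))  # True
```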
Motivation, Design Choices, and Pitfalls
The idea behind this operation is to decouple the layout setting/tiling required to target certain intrinsics from the lowering to them. Tiling of this sort typically happens on tensor operands, whereas the target intrinsics operate on vectors; this operation bridges the gap. The choice of a shared operation is intended to ease the lowering process and to allow different transformations at different stages of the pipeline without needing to clone this op.
The choice to let the inner dimensions required to compute the intrinsic be implicit based on the indexing maps was made to make this operation easier to generate and to skip the need for type conversion ops. However, this makes the operation harder to verify. It is also implicitly linked to a lane-level parent scf.forall operation.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), VectorUnrollOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
indexing_maps | ::mlir::ArrayAttr | array attribute |
iterator_types | ::mlir::ArrayAttr | Iterator type should be an enum. |
kind | IREE::GPU::MmaInterfaceAttr | MMA intrinsic that defines the inner computation |
lhs_permutation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
rhs_permutation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
acc_permutation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
lhs | ranked tensor or vector of any type values |
rhs | ranked tensor or vector of any type values |
acc | ranked tensor or vector of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor or vector of any type values |
iree_gpu.shuffle_tensor (GPU::ShuffleTensorOp)
Shuffles a private tensor across a shared allocation
Syntax:
operation ::= `iree_gpu.shuffle_tensor` $source ``
custom<DynamicIndexList>($offsets, $static_offsets)
custom<DynamicIndexList>($sizes, $static_sizes)
custom<DynamicIndexList>($strides, $static_strides)
`to` $dest $region attr-dict
`:` type($source) `->` type($dest) `->` type($result)
This op is designed to represent a shuffle of private tensor data collectively held across a set of workers. This operation naturally arises when combining the regions of producer-consumer scf.forall operations that share a mapping type and worker count.
For example, consider the following pair of parallel loops.
%0 = scf.forall (%idy, %idx) in (2, 32) shared_outs(%init = %empty) -> (tensor<4x128xf32>) {
%in = ...
%2 = affine.apply affine_map<(d0) -> (d0 * 2)> (%idy)
%3 = affine.apply affine_map<(d0) -> (d0 * 4)> (%idx)
scf.forall.in_parallel {
tensor.parallel_insert_slice %in into %init[%2, %3] [2, 4] [1, 1]
: tensor<2x4xf32> into tensor<4x128xf32>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
%1 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<128x128xf32>) {
%4 = affine.apply affine_map<(d0) -> (d0 * 16)> (%idx)
%extracted_slice = tensor.extract_slice %0[0, %4] [4, 16] [1, 1]
: tensor<4x128xf32> to tensor<4x16xf32>
...
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
Because these loops share the same worker type and total count, the bodies of the two loops can be merged, with a barrier and a shuffle placed at the current boundary between the loops.
%0 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<128x128xf32>) {
  %linear = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)> (%idy, %idx)
  %ids:2 = affine.delinearize_index %linear into (2, 32) : index, index
  %in = ...
  %2 = affine.apply affine_map<(d0) -> (d0 * 2)> (%ids#0)
  %3 = affine.apply affine_map<(d0) -> (d0 * 4)> (%ids#1)
  %4 = affine.apply affine_map<(d0) -> (d0 * 16)> (%idx)
  %slice = iree_gpu.shuffle_tensor %in[%2, %3] [2, 4] [1, 1] to %empty {
  ^bb0(%intermediate: tensor<4x128xf32>):
    %inner_slice = tensor.extract_slice %intermediate[0, %4] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
    iree_gpu.yield %inner_slice : tensor<4x16xf32>
  } : tensor<2x4xf32> -> tensor<4x128xf32> -> tensor<4x16xf32>
  ...
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
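The index arithmetic of the merged loop can be checked by brute force. The Python sketch below uses hypothetical helper names and makes three checks:

- It reinterprets each of the 8 x 8 workers as a (2, 32) producer id, the way the delinearization in the merged loop does.
- It counts how often each cell of the 4x128 intermediate is written and confirms every cell is written exactly once, so there are no conflicting writes.
- It confirms that the readers' 4x16 slices at column 16 * idx tile the intermediate.

```python
# Brute-force check of the merged loop's shuffle indexing:
# 8 x 8 workers, each writing a 2x4 tile of a 4x128 intermediate.
from typing import List, Sequence, Tuple


def delinearize(linear: int, basis: Sequence[int]) -> Tuple[int, ...]:
    """Row-major delinearization, as affine.delinearize_index does in range."""
    ids = []
    for size in reversed(basis):
        ids.append(linear % size)
        linear //= size
    return tuple(reversed(ids))


def write_counts(rows: int = 4, cols: int = 128) -> List[List[int]]:
    """How many workers write each cell of the intermediate."""
    counts = [[0] * cols for _ in range(rows)]
    for idy in range(8):
        for idx in range(8):
            py, px = delinearize(idy * 8 + idx, (2, 32))
            row0, col0 = py * 2, px * 4          # offsets [%2, %3]
            for r in range(row0, row0 + 2):      # sizes [2, 4]
                for c in range(col0, col0 + 4):
                    counts[r][c] += 1
    return counts


counts = write_counts()
print(all(n == 1 for row in counts for n in row))  # True: no conflicts
# Readers take columns [16 * idx, 16 * idx + 16) for idx in 0..7.
read_cols = sorted(c for idx in range(8) for c in range(16 * idx, 16 * idx + 16))
print(read_cols == list(range(128)))  # True
```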
A shuffle can be lowered to a shared allocation with a write of the source slice, a barrier, the inlined body of the shuffle op (the read), and then a barrier to synchronize all workers on the result of the read. Note that it is undefined behavior if there are any conflicting writes to the intermediate. Also, because every worker must reach the barrier, lowering the enclosing scf.forall to serial loops is invalid. In other words, the lowering must provide the number of workers requested by the loop.
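The lowering just described can be simulated with real threads. In the Python sketch below, a list stands in for the shared allocation and threading.Barrier stands in for the GPU barrier. The "read" body, which fetches the neighbouring worker's slice, is a hypothetical choice for illustration. If the workers instead ran one after another, the first wait() would never return. This mirrors why serializing the enclosing scf.forall is invalid.

```python
# Sketch of the shuffle lowering with real workers: write the private slice to
# a shared buffer, barrier, run the read (the inlined body), barrier again.
import threading
from typing import List


def run_shuffle(num_workers: int, tile: int) -> List[List[int]]:
    shared = [0] * (num_workers * tile)       # stands in for the shared allocation
    barrier = threading.Barrier(num_workers)  # every worker must arrive
    results: List[List[int]] = [[] for _ in range(num_workers)]

    def worker(wid: int) -> None:
        source = [wid * 100 + e for e in range(tile)]  # private slice
        shared[wid * tile:(wid + 1) * tile] = source    # 1. write (disjoint ranges)
        barrier.wait()                                  # 2. barrier on the writes
        peer = (wid + 1) % num_workers                  # 3. read: a neighbour's slice
        results[wid] = shared[peer * tile:(peer + 1) * tile]
        barrier.wait()                                  # 4. barrier on the read

    threads = [threading.Thread(target=worker, args=(w,)) for w in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_shuffle(4, 2))  # [[100, 101], [200, 201], [300, 301], [0, 1]]
```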
This op takes an input |source| tensor to represent the slice held by this worker before the shuffle, an intermediate tensor |dest| that all workers insert into, and performs a synchronized read from that intermediate tensor.
It is undefined behavior if the source tensor is out of bounds of the intermediate allocation.
Motivation and Intended Use Cases:
The primary way this op is generated is when fusing parallel loops with tensor results. This operation helps make lowerings more progressive and flexible:
- Rather than lowering straight to vector ops for the reads/writes of the shuffle, it allows separating the vectorization of the shared memory accesses from earlier tiling steps.
- Lowering directly to an alloc plus reads and writes breaks the dependency chain, potentially making transformations like barrier placement and pipelining more difficult.
- It allows the option of non-vector based lowering paths.
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments, SingleBlockImplicitTerminator<mlir::iree_compiler::IREE::GPU::YieldOp>, SingleBlock
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
static_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
source | ranked tensor of any type values |
offsets | variadic of index |
sizes | variadic of index |
strides | variadic of index |
dest | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor or vector of any type values |
iree_gpu.value_barrier (GPU::ValueBarrierOp)
Synchronizes workers on a value semantic tensor or vector
Syntax:
operation ::= `iree_gpu.value_barrier` $input attr-dict `:` type($result)
This operation acts as a barrier on a value semantic SSA value (tensor or vector). It takes a single operand and produces a value equivalent to the input. This does not have copy and/or data movement semantics and simply represents a barrier on all writes in the tensor case, and a barrier until all threads acquire the input vector in the vector case.
This operation is a no-op when not present in a parallel context. This operation is pure as it only requires synchronization for the value it produces.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
input | ranked tensor or vector of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor or vector of any type values |
iree_gpu.yield (GPU::YieldOp)
Yield a value from a region
Syntax:
operation ::= `iree_gpu.yield` $value attr-dict `:` type($value)
This operation is used to yield a single value from within a region.
Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::iree_compiler::IREE::GPU::ShuffleTensorOp>, ReturnLike, Terminator
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
value | any type |
Attributes
ComputeBitwidthsAttr
Supported bitwidths for compute
Syntax:
#iree_gpu.compute_bitwidths<
::mlir::iree_compiler::IREE::GPU::ComputeBitwidths # value
>
Enum cases:
* fp64 (FP64)
* fp32 (FP32)
* fp16 (FP16)
* int64 (Int64)
* int32 (Int32)
* int16 (Int16)
* int8 (Int8)
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::iree_compiler::IREE::GPU::ComputeBitwidths | an enum of type ComputeBitwidths |
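The summary "Supported bitwidths for compute" suggests a set of bitwidths rather than a single one. This page does not say so explicitly, but assuming the cases combine as flags, the Python sketch below models ComputeBitwidths with enum.Flag. It also parses a hypothetical "|"-separated spelling built from the lowercase case names listed above.

```python
# Sketch: ComputeBitwidths modelled as combinable flags (an assumption), with
# a parser for a hypothetical "fp32|fp16"-style spelling of the case names.
from enum import Flag, auto


class ComputeBitwidths(Flag):
    FP64 = auto()
    FP32 = auto()
    FP16 = auto()
    Int64 = auto()
    Int32 = auto()
    Int16 = auto()
    Int8 = auto()


# Lowercase keywords as listed under "Enum cases".
_BY_KEYWORD = {
    "fp64": ComputeBitwidths.FP64, "fp32": ComputeBitwidths.FP32,
    "fp16": ComputeBitwidths.FP16, "int64": ComputeBitwidths.Int64,
    "int32": ComputeBitwidths.Int32, "int16": ComputeBitwidths.Int16,
    "int8": ComputeBitwidths.Int8,
}


def parse_compute(text: str) -> ComputeBitwidths:
    """OR together the flags named in a '|'-separated keyword list."""
    mask = ComputeBitwidths(0)
    for keyword in text.split("|"):
        mask |= _BY_KEYWORD[keyword.strip()]
    return mask


supported = parse_compute("fp32|fp16|int32|int8")
print(ComputeBitwidths.FP16 in supported)  # True
print(ComputeBitwidths.FP64 in supported)  # False
```

The same pattern would apply to StorageBitwidthsAttr and SubgroupOpsAttr under the same assumption.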
DerivedThreadConfigAttr
Drive lowering of an operation by deriving thread distribution when needed.
Syntax: #iree_gpu.derived_thread_config
Lowering config for a single thread tiling level that is inferred after previous (often reduction) levels of tile + fuse. This is intended for fused operations where it is much easier to compute the tile sizes to use after previous levels of tile + fuse, rather than trying to pre-propagate tiling configs.
DotProductOpsAttr
Supported dot product ops
Syntax:
#iree_gpu.dotproduct_ops<
::mlir::iree_compiler::IREE::GPU::DotProductOps # value
>
Enum cases:
* none (None)
* dp4xi8toi32 (DP4xI8ToI32)
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::iree_compiler::IREE::GPU::DotProductOps | an enum of type DotProductOps |
IteratorTypeAttr
Iterator type
Syntax:
#iree_gpu.iterator_type<
::mlir::utils::IteratorType # value
>
Enum cases:
* parallel (parallel)
* reduction (reduction)
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::utils::IteratorType | an enum of type IteratorType |
LaneIdAttr
Syntax:
#iree_gpu.lane_id<
int64_t # dim
>
An attribute for mapping scf.forall ops to subgroup lanes.
Parameters:
Parameter | C++ type | Description |
---|---|---|
dim | int64_t | |
LoweringConfigAttr
Drive lowering of an operation for GPU compilation.
Syntax:
#iree_gpu.lowering_config<
DictionaryAttr # attributes
>
GPU specific implementation of a lowering config. This carries just a dictionary attribute to store any relevant fields. This is the simplest form of a lowering config, offering flexibility at the cost of structure.
Parameters:
Parameter | C++ type | Description |
---|---|---|
attributes | DictionaryAttr | The configured fields, including tiling levels |
MMAAttr
Attribute describing a particular shape of matrix-multiply and accumulate instruction. Abstractly, all attributes of this type represent the following unit of arithmetic for matrices A, B, and C.
C += A x B
where the shape of matrix A is [m, k], B is [k, n], and C is [m, n]. This intentionally leaves the layout information abstract and uses interface methods to materialize layout information only when needed. The shape of the MMA intrinsic is stored explicitly here because that information is queried frequently.
The element types for this particular MMA intrinsic are |aType|, |bType|, and |cType| for matrices A, B, and C, respectively.
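The shape bookkeeping described above fits in a few lines. The dataclass below is a hypothetical Python mirror of the mSize/nSize/kSize and aType/bType/cType parameters, not IREE's C++ class. It derives the operand shapes and the number of multiply-accumulates, m * n * k, performed by one C += A x B.

```python
# Hypothetical mirror of MMAAttr's shape parameters (not the IREE C++ API).
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MmaShape:
    m: int
    n: int
    k: int
    a_type: str
    b_type: str
    c_type: str

    def a_shape(self) -> Tuple[int, int]:
        return (self.m, self.k)  # A is [m, k]

    def b_shape(self) -> Tuple[int, int]:
        return (self.k, self.n)  # B is [k, n]

    def c_shape(self) -> Tuple[int, int]:
        return (self.m, self.n)  # C is [m, n]

    def macs(self) -> int:
        """Multiply-accumulates in one C += A x B."""
        return self.m * self.n * self.k


mfma = MmaShape(32, 32, 8, "f16", "f16", "f32")
print(mfma.a_shape(), mfma.b_shape(), mfma.c_shape())  # (32, 8) (8, 32) (32, 32)
print(mfma.macs())  # 8192
```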
This MMA variant describes configurations for MMA ops. The |intrinsic| field specifies which particular MMA intrinsic this refers to, with each intrinsic implying a specific MNK shape and operand types. The intrinsic enum name encodes these fields as
ARCH_InputType_MxNxK_OutputType
where the element types of the A and B matrices are both InputType. For example, MFMA_F16_16x16x16_F32 has M = N = K = 16, f16 A and B matrices, and an f32 C matrix.
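Assuming every enum case follows this naming pattern, as the four cases listed under MMAIntrinsicAttr do, the fields can be recovered from the name alone. The regex and NamedTuple below are an illustrative Python sketch, not necessarily how IREE derives them.

```python
# Sketch: split an MMA intrinsic enum name of the form
# ARCH_InputType_MxNxK_OutputType into its fields.
import re
from typing import NamedTuple


class IntrinsicName(NamedTuple):
    arch: str
    input_type: str  # element type of both A and B
    m: int
    n: int
    k: int
    output_type: str


_PATTERN = re.compile(r"([A-Z]+)_([A-Z0-9]+)_(\d+)x(\d+)x(\d+)_([A-Z0-9]+)")


def parse_intrinsic(name: str) -> IntrinsicName:
    match = _PATTERN.fullmatch(name)
    if match is None:
        raise ValueError(f"unrecognized intrinsic name: {name}")
    arch, in_ty, m, n, k, out_ty = match.groups()
    return IntrinsicName(arch, in_ty, int(m), int(n), int(k), out_ty)


print(parse_intrinsic("MFMA_F16_32x32x8_F32"))
# IntrinsicName(arch='MFMA', input_type='F16', m=32, n=32, k=8, output_type='F32')
```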
Parameters:
Parameter | C++ type | Description |
---|---|---|
intrinsic | MMAIntrinsicAttr | |
mSize | int64_t | |
nSize | int64_t | |
kSize | int64_t | |
aType | ::mlir::Type | |
bType | ::mlir::Type | |
cType | ::mlir::Type | |
MMAIntrinsicAttr
Descriptor for different MMA intrinsics
Syntax:
#iree_gpu.mma_intrinsic<
::mlir::iree_compiler::IREE::GPU::MMAIntrinsic # value
>
Enum cases:
* MFMA_F16_16x16x16_F32 (MFMA_F16_16x16x16_F32)
* MFMA_F16_32x32x8_F32 (MFMA_F16_32x32x8_F32)
* WMMA_F16_16x16x16_F32 (WMMA_F16_16x16x16_F32)
* WMMA_F16_16x16x16_F16 (WMMA_F16_16x16x16_F16)
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::iree_compiler::IREE::GPU::MMAIntrinsic | an enum of type MMAIntrinsic |
MMAOpsArrayAttr
Syntax:
#iree_gpu.mma_ops<
::llvm::ArrayRef<MMAAttr> # value
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::llvm::ArrayRef<MMAAttr> |
MMAScheduleAttr
Syntax:
#iree_gpu.mma_schedule<
::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr, # intrinsic
int64_t, # subgroup_m_count
int64_t # subgroup_n_count
>
A schedule of an MMA intrinsic instruction and various levels of tile sizes to solve a specific contraction problem.
Parameters:
Parameter | C++ type | Description |
---|---|---|
intrinsic | ::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr | |
subgroup_m_count | int64_t | |
subgroup_n_count | int64_t | |
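As a rough illustration of how these parameters combine, assume two things this page does not say: the subgroups form a subgroup_m_count x subgroup_n_count grid, and each subgroup owns at least one intrinsic tile along M and along N. Under those assumptions, the hypothetical helper below computes the subgroup count and the smallest workgroup tile such a schedule covers. The subgroup size is a separate input that this attribute does not carry.

```python
# Back-of-the-envelope sketch for an MMA schedule (assumed subgroup grid).
from typing import Dict


def schedule_footprint(m_size: int, n_size: int, subgroup_m_count: int,
                       subgroup_n_count: int, subgroup_size: int) -> Dict[str, int]:
    subgroups = subgroup_m_count * subgroup_n_count
    return {
        "subgroups": subgroups,
        "threads": subgroups * subgroup_size,     # if the workgroup is just these subgroups
        "min_tile_m": subgroup_m_count * m_size,  # one intrinsic tile per subgroup along M
        "min_tile_n": subgroup_n_count * n_size,  # ... and along N
    }


# 16x16 intrinsic, 2 x 2 subgroups, subgroup size 64 (assumed).
print(schedule_footprint(16, 16, 2, 2, 64))
# {'subgroups': 4, 'threads': 256, 'min_tile_m': 32, 'min_tile_n': 32}
```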
StorageBitwidthsAttr
Supported bitwidths for storage
Syntax:
#iree_gpu.storage_bitwidths<
::mlir::iree_compiler::IREE::GPU::StorageBitwidths # value
>
Enum cases:
* b64 (B64)
* b32 (B32)
* b16 (B16)
* b8 (B8)
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::iree_compiler::IREE::GPU::StorageBitwidths | an enum of type StorageBitwidths |
SubgroupOpsAttr
Supported subgroup ops
Syntax:
#iree_gpu.subgroup_ops<
::mlir::iree_compiler::IREE::GPU::SubgroupOps # value
>
Enum cases:
* none (None)
* shuffle (Shuffle)
* arithmetic (Arithmetic)
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::iree_compiler::IREE::GPU::SubgroupOps | an enum of type SubgroupOps |
TargetAttr
Full GPU target attribute
Syntax:
#iree_gpu.target<
::llvm::StringRef, # arch
::llvm::StringRef, # features
TargetWgpAttr, # wgp
TargetChipAttr # chip
>
This attribute describes a full GPU target. It contains a few fields:
* The canonical target architecture for compilation, e.g., sm_80 for CUDA, gfx942 for HIP
* A TargetWgpAttr describing the GPU features and limits in a single GPU workgroup processor (WGP), that is, an AMD compute unit or an NVIDIA streaming multiprocessor
* An optional TargetChipAttr describing GPU features for the final chip or product, e.g., WGP count
Parameters:
Parameter | C++ type | Description |
---|---|---|
arch | ::llvm::StringRef | target architecture |
features | ::llvm::StringRef | target features |
wgp | TargetWgpAttr | |
chip | TargetChipAttr | |
TargetChipAttr
Chip level target description
Syntax:
#iree_gpu.target_chip<
uint32_t, # wgp_count
DictionaryAttr # extra
>
This attribute contains hardware features/limits at a single GPU chip level. Here a GPU chip means the hardware functionality scope onto which the whole software compute grid is scheduled. A chip typically contains many AMD compute units or NVIDIA streaming multiprocessors; it is the final SKU.
Parameters:
Parameter | C++ type | Description |
---|---|---|
wgp_count | uint32_t | |
extra | DictionaryAttr | |
TargetWgpAttr
Workgroup processor level target description
Syntax:
#iree_gpu.target_wgp<
ComputeBitwidthsAttr, # compute
StorageBitwidthsAttr, # storage
SubgroupOpsAttr, # subgroup
DotProductOpsAttr, # dot
MMAOpsArrayAttr, # mma
DenseI32ArrayAttr, # subgroup_size_choices
DenseI32ArrayAttr, # max_workgroup_sizes
uint32_t, # max_thread_count_per_workgroup
uint32_t, # max_workgroup_memory_bytes
DictionaryAttr # extra
>
This attribute contains hardware features/limits at a single GPU workgroup processor (WGP) level. Here a GPU workgroup processor means the basic hardware functionality unit onto which a software workgroup is scheduled; that is, a compute unit for AMD GPUs or a streaming multiprocessor for NVIDIA GPUs.
Parameters:
Parameter | C++ type | Description |
---|---|---|
compute | ComputeBitwidthsAttr | |
storage | StorageBitwidthsAttr | |
subgroup | SubgroupOpsAttr | |
dot | DotProductOpsAttr | |
mma | MMAOpsArrayAttr | |
subgroup_size_choices | DenseI32ArrayAttr | |
max_workgroup_sizes | DenseI32ArrayAttr | |
max_thread_count_per_workgroup | uint32_t | |
max_workgroup_memory_bytes | uint32_t | |
extra | DictionaryAttr | |
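A typical consumer of TargetWgpAttr checks a candidate launch configuration against its limits. The Python sketch below does that with a hypothetical WgpLimits record holding four of the fields above. The numbers in the usage example are made up for illustration and do not describe any particular GPU.

```python
# Sketch: check a launch configuration against workgroup-processor limits.
from dataclasses import dataclass
from math import prod
from typing import Tuple


@dataclass(frozen=True)
class WgpLimits:
    subgroup_size_choices: Tuple[int, ...]
    max_workgroup_sizes: Tuple[int, ...]  # per-dimension caps (x, y, z)
    max_thread_count_per_workgroup: int
    max_workgroup_memory_bytes: int


def fits(limits: WgpLimits, workgroup_size: Tuple[int, ...],
         subgroup_size: int, shared_bytes: int) -> bool:
    """True if the configuration respects every modelled limit."""
    return (
        subgroup_size in limits.subgroup_size_choices
        and len(workgroup_size) <= len(limits.max_workgroup_sizes)
        and all(s <= cap for s, cap in zip(workgroup_size, limits.max_workgroup_sizes))
        and prod(workgroup_size) <= limits.max_thread_count_per_workgroup
        and shared_bytes <= limits.max_workgroup_memory_bytes
    )


limits = WgpLimits((64,), (1024, 1024, 1024), 1024, 65536)  # illustrative values
print(fits(limits, (256, 1, 1), 64, 32768))  # True
print(fits(limits, (1024, 2, 1), 64, 0))     # False: 2048 threads > 1024
```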