RFC: Encoding Propagation Interfaces
HackMD version: https://hackmd.io/@hwPnnvLBTB-JGVMeh-bCEA/SkKbbnXiJl
Overview
Data-Tiling introduces tensor encodings on operations. The encodings are set by iree_encoding.set_encoding ops and removed by iree_encoding.unset_encoding ops. Currently, we either fuse set_encoding ops with their producers or form them into their own dispatches. They are usually lowered to relayout ops, like pack/unpack/reshape ops, etc. In the LLaMA and SDXL models, we identified that this is a sub-optimal solution because we want to propagate encodings across other operations, fuse them with other producers, or hoist them to initializers.
E.g., the Llama2 attention KV cache is a global variable and is used immediately by a matmul. We want to propagate the packed layout to the KV cache, saving the packing overhead in attention. This was done on an experimental CPU path that materializes the encodings at the program level and propagates the pack/unpack ops. [result, prototype] However, data layout propagation is harder when it happens on encodings, because we need to bake the encoded information into attributes.
There are two missing features for encoding propagation. One is the ability to propagate encodings. The other is the propagation strategy, which is built on top of the former. This doc is mainly about the propagation ability, which unblocks IREE to explore encoding propagation.
Proposal
The mechanism needs to be reusable for arbitrary encodings. Thus, the proposal is to build the mechanism with interfaces. The core step is moving set_encoding and unset_encoding ops around, because all the encodings are introduced and removed by these ops. The propagation happens when we bubble up set_encoding ops and push down unset_encoding ops. See the two examples below.
Note that the new encodings can be the same as the original encodings, especially for element-wise operations, but in general they can be different; which one is produced is an implementation detail. The sketch below uses the same encoding for simplicity.
Side note: we could create an [Un]SetEncodingLike op interface for the set_encoding and unset_encoding ops for plugin requirements, but we likely won't see that case in the near future.
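A minimal IR sketch of both directions (the shapes, value names, and the #enc alias, which stands for an #iree_encoding.encoding attribute with its details elided, are hypothetical; the ops are elementwise, so the encoding stays unchanged):
SetEncoding:
// Before: the encoding is set on the result of an elementwise op.
%exp = linalg.exp ins(%src : tensor<16x32xf32>)
    outs(%init : tensor<16x32xf32>) -> tensor<16x32xf32>
%0 = iree_encoding.set_encoding %exp : tensor<16x32xf32> -> tensor<16x32xf32, #enc>

// After bubbling up: the encodings are set on the operands and the op computes
// on encoded tensors directly.
%src_e  = iree_encoding.set_encoding %src  : tensor<16x32xf32> -> tensor<16x32xf32, #enc>
%init_e = iree_encoding.set_encoding %init : tensor<16x32xf32> -> tensor<16x32xf32, #enc>
%exp_e = linalg.exp ins(%src_e : tensor<16x32xf32, #enc>)
    outs(%init_e : tensor<16x32xf32, #enc>) -> tensor<16x32xf32, #enc>
UnSetEncoding:
// Before: the encoding is removed right before an elementwise consumer.
%1 = iree_encoding.unset_encoding %in_e : tensor<16x32xf32, #enc> -> tensor<16x32xf32>
%abs = linalg.abs ins(%1 : tensor<16x32xf32>)
    outs(%dst : tensor<16x32xf32>) -> tensor<16x32xf32>

// After pushing down: the consumer computes on the encoded tensor and the
// encoding is removed afterwards.
%dst_e = iree_encoding.set_encoding %dst : tensor<16x32xf32> -> tensor<16x32xf32, #enc>
%abs_e = linalg.abs ins(%in_e : tensor<16x32xf32, #enc>)
    outs(%dst_e : tensor<16x32xf32, #enc>) -> tensor<16x32xf32, #enc>
%2 = iree_encoding.unset_encoding %abs_e : tensor<16x32xf32, #enc> -> tensor<16x32xf32>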
No matter how the encoding propagation is designed, we need to consider how the IR is transformed and what the new encoding is. The encoding attribute is the only one that knows what the new encodings are; the operation is the only one that knows what the transformed ops are. Thus, the proposal is to create a PropagationAttrInterface and a PropagationOpInterface in the IREE Encoding dialect; the encoding attributes implement the attribute interface and the relevant operations implement the op interface.
PropagationAttrInterface
How to generate the new encodings is attribute-specific because only the attribute knows the details of the encoded information. E.g., an attribute can encode indexing maps, and the indexing maps would change if the given operation is a reshape-like op. If the attribute only records reduction dimensions, it returns the same encoding as long as the reduction dimensions are not changed.
bool isPropagable(Value target): Returns true if the encoding can be propagated across the operation.
FailureOr<PropagationEncoding> generateEncodings(Value target): Returns the new encodings for operand types and result types for the given operation.
PropagationOpInterface
How to transform operations is operation-specific because only the operation knows how to handle the encodings. E.g., tensor.empty ops can return themselves with the given encodings directly because they are shape-like ops; in this context, there are no new set_encoding ops. A linalg op might create as many set_encoding ops as it has operands/results and clone itself with the new encoded operands.
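As a minimal sketch (hypothetical shapes and a #enc alias for the encoding attribute), the tensor.empty case needs no new encoding ops at all:
// Before:
%empty = tensor.empty() : tensor<16x32xf32>
%0 = iree_encoding.set_encoding %empty : tensor<16x32xf32> -> tensor<16x32xf32, #enc>

// After: the empty op simply takes the encoded type and its result replaces the
// set_encoding result; no new set_encoding/unset_encoding ops are generated.
%empty_e = tensor.empty() : tensor<16x32xf32, #enc>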
For arith.constant, we may transform it to another dialect's constant that allows encodings, or just set encodings on the result of the constant op.
Furthermore, global load/store ops may need to update the globals with encodings and load/store values from/to the encoded globals. They may also want to generate util.initializer ops to handle the transformation from the original data to the encoded data.
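A hypothetical sketch of the global case (the @kv_cache name, shapes, and #enc alias are made up, and the util op syntax is abbreviated):
// Before: the global holds the original layout and the consumer sets the encoding.
util.global private mutable @kv_cache : tensor<16x32xf32>
%v = util.global.load @kv_cache : tensor<16x32xf32>
%0 = iree_encoding.set_encoding %v : tensor<16x32xf32> -> tensor<16x32xf32, #enc>

// After: the global itself uses the encoded layout. Stores into the global set
// the encoding on the stored value, and a util.initializer (not shown) can
// convert an original initial value into the encoded layout.
util.global private mutable @kv_cache : tensor<16x32xf32, #enc>
%v_e = util.global.load @kv_cache : tensor<16x32xf32, #enc>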
Another advanced example is propagating encodings to function arguments. In this case, we need to locate and update all the function calls.
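A hypothetical sketch for the function-argument case (names and shapes are made up):
// Before: the callee sets the encoding on its argument.
func.func @callee(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32, #enc> {
  %0 = iree_encoding.set_encoding %arg0 : tensor<16x32xf32> -> tensor<16x32xf32, #enc>
  return %0 : tensor<16x32xf32, #enc>
}

// After: the argument type carries the encoding and every call site is updated
// to pass an encoded value.
func.func @callee(%arg0: tensor<16x32xf32, #enc>) -> tensor<16x32xf32, #enc> {
  return %arg0 : tensor<16x32xf32, #enc>
}
%x_e = iree_encoding.set_encoding %x : tensor<16x32xf32> -> tensor<16x32xf32, #enc>
%r = func.call @callee(%x_e) : (tensor<16x32xf32, #enc>) -> tensor<16x32xf32, #enc>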
These hypotheses show that an op interface is needed. The results of the propagation are returned in a PropagationResult struct:
struct PropagationResult {
  // A list of operations that are created by the propagation. They are returned
  // to the caller for further transformation.
  SmallVector<Operation *> encodedOps;
  // A list of new set_encoding/unset_encoding ops that are generated by
  // the propagation.
  SmallVector<Operation *> generatedEncodingOps;
  // The new corresponding result that is created by the propagation. It is
  // returned to the caller for further transformation or replacement.
  Value replacement;
};
FailureOr<PropagationResult> propagateEncoding(RewriterBase &rewriter, PropagationEncoding encodings, OpResult opResult): Transforms the op with the given encodings. It replaces the uses of the op, except the opResult, if it succeeds. We might want to pass the moduleOp and symbolTable to the method because it is a whole-program optimization.
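As a rough illustration (hypothetical shapes, a #enc alias for the encoding attribute, and an exp-like two-result linalg.generic whose second result feeds a set_encoding op):
// Before: only the second result feeds a set_encoding op.
#map = affine_map<(d0, d1) -> (d0, d1)>
%r:2 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel"]}
    ins(%src : tensor<16x32xf32>)
    outs(%init0, %init1 : tensor<16x32xf32>, tensor<16x32xf32>) {
^bb0(%in: f32, %out0: f32, %out1: f32):
  %e = math.exp %in : f32
  linalg.yield %e, %e : f32, f32
} -> (tensor<16x32xf32>, tensor<16x32xf32>)
%0 = iree_encoding.set_encoding %r#1 : tensor<16x32xf32> -> tensor<16x32xf32, #enc>

// After propagateEncoding: the operands are encoded, the op is cloned with
// encoded types, and the unencoded uses of the first result go through a new
// unset_encoding op.
%src_e   = iree_encoding.set_encoding %src   : tensor<16x32xf32> -> tensor<16x32xf32, #enc>
%init0_e = iree_encoding.set_encoding %init0 : tensor<16x32xf32> -> tensor<16x32xf32, #enc>
%init1_e = iree_encoding.set_encoding %init1 : tensor<16x32xf32> -> tensor<16x32xf32, #enc>
%r_e:2 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel"]}
    ins(%src_e : tensor<16x32xf32, #enc>)
    outs(%init0_e, %init1_e : tensor<16x32xf32, #enc>, tensor<16x32xf32, #enc>) {
^bb0(%in: f32, %out0: f32, %out1: f32):
  %e = math.exp %in : f32
  linalg.yield %e, %e : f32, f32
} -> (tensor<16x32xf32, #enc>, tensor<16x32xf32, #enc>)
%r0 = iree_encoding.unset_encoding %r_e#0 : tensor<16x32xf32, #enc> -> tensor<16x32xf32>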
In this case, generatedEncodingOps contains the new set_encoding and unset_encoding ops, and replacement is the second result of the new encoded exp op.
Note: we will implement it in compiler/ExternalInterfaces using the External Model mechanism, like what we've done for HoistableOpInterface.
Assemble Pieces
The PropagationAttrInterface provides the new encodings, and the PropagationOpInterface applies the propagation and returns all the information. Together they provide all the pieces needed by the propagation. A snippet of the basic implementation can be:
// A sketch of the driver that bubbles a set_encoding op up across its producer.
Value src = setEncoding.getSource();
auto encoding =
    dyn_cast<PropagationAttrInterface>(setEncoding.getEncoding());
auto srcOp =
    dyn_cast_or_null<PropagationOpInterface>(src.getDefiningOp());
if (!encoding || !srcOp || !encoding.isPropagable(src)) {
  return;
}
// Ask the attribute for the new operand/result encodings.
FailureOr<PropagationEncoding> maybeEncodings =
    encoding.generateEncodings(src);
if (failed(maybeEncodings)) {
  return;
}
// Ask the producer op to apply the propagation with those encodings.
FailureOr<PropagationResult> maybeResult =
    srcOp.propagateEncoding(rewriter, *maybeEncodings, cast<OpResult>(src));
if (failed(maybeResult)) {
  return;
}
// Keep propagating through the newly created encoding ops.
worklist.append(maybeResult->generatedEncodingOps);
rewriter.replaceOp(setEncoding, maybeResult->replacement);
Execution Plan
The propagation currently happens in a few patterns. One is in the SetEncoding pass, which propagates the encoding across linalg.fill and tensor.empty ops. Another is in the HoistEncodingOps pass, which propagates the encodings across broadcast ops. The other is in the FuseEncodingOpsIntoDispatchRegions pass, which moves the encodings into dispatch ops. We can start with these in order and adapt them to use the interfaces.
I talked to @pashu123 about next steps. There are two categories of work in my mind.
Port the existing logic to use the interfaces. We can start with bubbleUpSetEncoding, which handles the broadcast case for EncodingAttr, and then try to remove the FoldFillWithSetEncoding pattern and turn it into encoding propagation.
Look at the IR in the models and write down what the propagated encodings should be in issues. If we are not able to express them with the current semantics, we will discuss and add support.
Furthermore, we need a better encoding propagation setup. Currently, the propagation is modeled by several passes:
Local propagation within dispatch regions
Hoist encodings out of dispatch regions
Encoding propagation
Fuse encodings into dispatch regions
If we finish the two categories of tasks, we reach v1, which aggressively fuses encodings into dispatch regions.
In v2, we may want to do more propagation until it reaches the program boundary or some barriers. This requires analysis, IMO. Furthermore, we will iteratively apply steps 1 to 4 in this context. A rough idea is implementing the PropagationOpInterface for the dispatch region op, so the dispatch region op can do local propagation and hoisting in its implementation.
v2 is just a rough idea I have for now; it needs more discussion. For now, let's focus on v1. @pashu123