RFC: Encoding Propagation Interfaces #20179

Open

hanhanW opened this issue Mar 6, 2025 · 2 comments
Assignees
Labels
codegen Shared code generation infrastructure and dialects

Comments

@hanhanW
Contributor
hanhanW commented Mar 6, 2025

RFC: Encoding Propagation Interfaces

HackMD version: https://hackmd.io/@hwPnnvLBTB-JGVMeh-bCEA/SkKbbnXiJl

Overview

Data-Tiling introduces tensor encodings on operations. The encodings are set by iree_encoding.set_encoding ops and removed by iree_encoding.unset_encoding ops. Currently, we either fuse set_encoding ops with their producers or form them into their own dispatches. They are usually lowered to relayout ops, such as pack/unpack/reshape/etc. In LLaMA and SDXL, we identified that this is a sub-optimal solution because we want to propagate encodings across other operations, fuse them with other producers, or hoist them to initializers.

E.g., the Llama2 attention KV cache is a global variable that is used immediately by a matmul. We want to propagate the packed layout to the KV cache, saving the packing overhead in attention. This was done on an experimental CPU path that materializes the encodings at the program level and propagates the pack/unpack ops. [result, prototype] However, data layout propagation is harder when it happens on encodings, because we need to bake the encoded information into attributes.

There are two missing features for encoding propagation. One is the ability to propagate encodings. The other is the propagation strategy, which is built on top of the former. This doc mainly covers the propagation ability, which unblocks IREE to explore encoding propagation.

Proposal

The mechanism needs to be reusable for arbitrary encodings. Thus, the proposal is to build the mechanism using interfaces. The core step is moving set_encoding and unset_encoding ops around, because all encodings are introduced and removed by these ops. The propagation happens when we bubble up set_encoding ops and push down unset_encoding ops. See the two examples below.

Note that the new encodings can be the same as the original encodings, especially for element-wise operations. We use different encodings in the examples only to show that they can differ; it is an implementation detail.

Side note: we could create an [Un]SetEncodingLike op interface for set_encoding and unset_encoding ops to satisfy plugin requirements, but we likely won't see that case in the near future.

SetEncoding:

%init = tensor.empty() : tensor<?xf32>
%0 = linalg.exp
  ins(%src : tensor<?xf32>)
  outs(%init : tensor<?xf32>)
%1 = iree_encoding.set_encoding %0
  : tensor<?xf32> -> tensor<?xf32, #encoding>

-->

%init = tensor.empty() : tensor<?xf32>
%encoded_src = iree_encoding.set_encoding %src
  : tensor<?xf32>  -> tensor<?xf32, #new_encoding0>
%encoded_init = iree_encoding.set_encoding %init
  : tensor<?xf32>  -> tensor<?xf32, #new_encoding1>
%1 = linalg.exp
  ins(%encoded_src : tensor<?xf32, #new_encoding0>)
  outs(%encoded_init : tensor<?xf32, #new_encoding1>)

UnSetEncoding:

%0 = iree_encoding.unset_encoding %src
  : tensor<?xf32, #encoding> -> tensor<?xf32>
%init = tensor.empty() : tensor<?xf32>
%1 = linalg.exp
  ins(%0 : tensor<?xf32>)
  outs(%init : tensor<?xf32>)

-->

%init = tensor.empty() : tensor<?xf32>
%encoded_init = iree_encoding.set_encoding %init
  : tensor<?xf32> -> tensor<?xf32, #new_encoding>
%0 = linalg.exp
  ins(%src : tensor<?xf32, #encoding>)
  outs(%encoded_init : tensor<?xf32, #new_encoding>)
%1 = iree_encoding.unset_encoding %0
  : tensor<?xf32, #new_encoding> -> tensor<?xf32>

No matter how the encoding propagation is designed, we need to consider how the IR is transformed and what the new encodings are. The encoding attribute is the only one that knows what the new encodings are; the operation is the only one that knows what the transformed ops are. Thus, the proposal is to create a PropagationAttrInterface and a PropagationOpInterface in the IREE Encoding dialect, and have encoding attributes and operations implement them.

PropagationAttrInterface

How to generate the new encodings is attribute-specific, because only the attribute knows the details of the encoded information. E.g., an attribute can encode indexing maps, and the indexing maps would change if the given operation is a reshape-like op. If the attribute only records reduction dimensions, it returns the same encoding as long as the reduction dimensions are unchanged.

struct PropagationEncoding {
    SmallVector<Attribute> operandEncodings;
    SmallVector<Attribute> resultEncodings;
};
  • bool isPropagable(Value target): Returns true if the encoding can be propagated across the operation.
  • FailureOr<PropagationEncoding> generateEncodings(Value target): Returns the new encodings for operand types and result types for the given operation.
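For illustration, here is a minimal sketch of how an attribute could implement these methods for element-wise producers, simply reusing the same encoding for every operand and result. MyEncodingAttr and the method bodies are hypothetical; only the interface methods themselves are what this RFC proposes.

// Minimal sketch, assuming a hypothetical MyEncodingAttr that implements
// PropagationAttrInterface. Element-wise producers do not permute or reshape
// dimensions, so the original encoding can be reused everywhere.
bool MyEncodingAttr::isPropagable(Value target) {
  // Only handle element-wise linalg ops in this sketch.
  auto genericOp = dyn_cast_or_null<linalg::GenericOp>(target.getDefiningOp());
  return genericOp && linalg::isElementwise(genericOp);
}

FailureOr<PropagationEncoding> MyEncodingAttr::generateEncodings(Value target) {
  auto genericOp = cast<linalg::GenericOp>(target.getDefiningOp());
  PropagationEncoding encodings;
  // Reuse this encoding for every operand and result of the producer.
  encodings.operandEncodings.assign(genericOp->getNumOperands(), *this);
  encodings.resultEncodings.assign(genericOp->getNumResults(), *this);
  return encodings;
}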

PropagationOpInterface

How to transform operations is operation-specific, because only the operation knows how to handle the encodings. E.g., tensor.empty ops can return themselves with the given encodings directly because they are shape-like ops; in this case, no new set_encoding ops are created. A linalg op might create as many set_encoding ops as it has operands/results and clone itself with the newly encoded operands.

For arith.constant, we may transform it into a constant from another dialect that allows encodings, or just set encodings on the result of the constant op.

Furthermore, global load/store ops may need to update the globals with encodings and load/store values from/to the encoded globals. It also may want to generate util.initializer ops to handle the transformation from original data to encoded data.

Another advanced example is propagating encodings to function arguments. In this case, we need to locate and update all the function calls.

These hypotheses show that an op interface is needed.

struct PropagationResult {
    // A list of operations that are created by the propagation. They are returned
    // to the caller for further transformation.
    SmallVector<Operation *> encodedOps;
    
    // A list of new set_encoding/unset_encoding ops that are generated by
    // the propagation.
    SmallVector<Operation *> generatedEncodingOps;
    
    // The new corresponding result that is created by the propagation. It is
    // returned to the caller for further transformation or replacement.
    Value replacement;
};
  • FailureOr<PropagationResult> propagateEncoding(RewriterBase &rewriter, PropagationEncoding encodings, OpResult opResult): Transforms the op with the given encodings. If it succeeds, it replaces all uses of the op except opResult. We might want to pass the moduleOp and symbolTable to the method because it is a whole-program optimization.

E.g.,

%0:2 = linalg.multi_result
   ins(%src: tensor<?xf32>)
   outs(%init_0, %init_1: tensor<?xf32>, tensor<?xf32>)
%1 = iree_encoding.set_encoding %0#1
  : tensor<?xf32> -> tensor<?xf32, #encoding>
  
-- can be transformed to -->

%encoded_src = iree_encoding.set_encoding %src
  : tensor<?xf32> -> tensor<?xf32, #new_encoding0>
%encoded_init_0 = iree_encoding.set_encoding %init_0
  : tensor<?xf32> -> tensor<?xf32, #new_encoding1>
%encoded_init_1 = iree_encoding.set_encoding %init_1
  : tensor<?xf32> -> tensor<?xf32, #new_encoding2>
%1:2 = linalg.multi_result
  ins(%encoded_src : tensor<?xf32, #new_encoding0>)
  outs(%encoded_init_0, %encoded_init_1
    : tensor<?xf32, #new_encoding1>, tensor<?xf32, #new_encoding2>)
%2 = iree_encoding.unset_encoding %1#0
  : tensor<?xf32, #new_encoding1> -> tensor<?xf32>

In this case, generatedEncodingOps contains the new set_encoding and unset_encoding ops. The replacement is the second result of the new linalg.multi_result op.

Note: we will implement it in compiler/ExternalInterfaces using the External Model mechanism, like what we've done for HoistableOpInterface.
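As a rough illustration of that external-model approach, a sketch for tensor.empty could look like the following. The interface and struct names come from this proposal; the exact method signature, builder overloads, and registration details are assumptions.

// Minimal sketch of an external model for tensor.empty, assuming the proposed
// PropagationOpInterface. tensor.empty is shape-like, so it is recreated
// directly with the encoded type and no new set_encoding ops are generated.
struct EmptyOpPropagation
    : public PropagationOpInterface::ExternalModel<EmptyOpPropagation,
                                                   tensor::EmptyOp> {
  FailureOr<PropagationResult>
  propagateEncoding(Operation *op, RewriterBase &rewriter,
                    PropagationEncoding encodings, OpResult opResult) const {
    auto emptyOp = cast<tensor::EmptyOp>(op);
    // Recreate the empty op with the new result encoding.
    auto newEmpty = rewriter.create<tensor::EmptyOp>(
        emptyOp.getLoc(), emptyOp.getType().getShape(),
        emptyOp.getType().getElementType(), emptyOp.getDynamicSizes(),
        encodings.resultEncodings.front());
    PropagationResult result;
    result.encodedOps.push_back(newEmpty);
    result.replacement = newEmpty.getResult();
    return result;
  }
};

// Registered like other external models in compiler/ExternalInterfaces.
void registerEncodingPropagationExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    tensor::EmptyOp::attachInterface<EmptyOpPropagation>(*ctx);
  });
}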

Assemble Pieces

The PropagationAttrInterface provides the new encodings, and the PropagationOpInterface applies the propagation and returns all the information; together they provide all the pieces needed by the propagation. A snippet of the basic implementation can look like:

Value src = setEncoding.getSource();
auto encoding =
    dyn_cast<PropagationAttrInterface>(setEncoding.getEncoding());
auto srcOp =
    dyn_cast_or_null<PropagationOpInterface>(src.getDefiningOp());
if (!encoding || !srcOp || !encoding.isPropagable(src)) {
    return;
}

FailureOr<PropagationEncoding> maybeEncodings =
    encoding.generateEncodings(src);
if (failed(maybeEncodings)) {
    return;
}

FailureOr<PropagationResult> maybeResult =
    srcOp.propagateEncoding(rewriter, *maybeEncodings, cast<OpResult>(src));
if (failed(maybeResult)) {
    return;
}

// Queue the newly generated encoding ops for further propagation.
worklist.append(maybeResult->generatedEncodingOps);
rewriter.replaceOp(setEncoding, maybeResult->replacement);
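For completeness, a hedged sketch of how this snippet might be driven: seed a worklist with the set_encoding ops in a function and keep popping until nothing propagates anymore. The helper name tryBubbleUpSetEncoding, standing in for the snippet above, is hypothetical.

// Hypothetical driver sketch around the snippet above. It seeds the worklist
// with all existing set_encoding ops and keeps propagating until it is empty.
void propagateEncodingsGreedily(RewriterBase &rewriter,
                                FunctionOpInterface funcOp) {
  SmallVector<Operation *> worklist;
  funcOp->walk(
      [&](IREE::Encoding::SetEncodingOp op) { worklist.push_back(op); });
  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();
    if (auto setEncoding = dyn_cast<IREE::Encoding::SetEncodingOp>(op)) {
      // Assumed to contain the logic from the snippet above and to append
      // newly created encoding ops to `worklist`.
      tryBubbleUpSetEncoding(rewriter, setEncoding, worklist);
    }
  }
}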

Execution Plan

The propagation currently happens in a few patterns. One is in the SetEncoding pass, which propagates the encoding across linalg.fill and tensor.empty ops. Another is in the HoistEncodingOps pass, which propagates the encodings across broadcast ops. The last is in the FuseEncodingOpsIntoDispatchRegions pass, which moves the encodings into dispatch ops. We can start with these in order and adapt them to use the interfaces.

@hanhanW
Contributor Author
hanhanW commented Apr 30, 2025

#20567 adds the interfaces and implements them for matmul_k encoding propagation.

@hanhanW
Contributor Author
hanhanW commented Apr 30, 2025

I talked to @pashu123 about next steps. There are two categories of work in my mind.

  1. Port the existing logic to use the interfaces. We can start with bubbleUpSetEncoding, which handles the broadcast case for EncodingAttr, and then try to remove the FoldFillWithSetEncoding pattern and leverage encoding propagation instead.
  2. Look at the IR in the models and write down what the propagated encodings should be in issues. If we are not able to express them with the current semantics, we will discuss and add support.

Furthermore, we need a better encoding propagation setup. Currently, the propagation is modeled by several passes:

  1. Local propagation within dispatch region
  2. Hoist encodings out of dispatch regions
  3. Encoding propagation
  4. Fuse encodings into dispatch regions

If we finish the two categories of tasks, we reach v1, which aggressively fuses encodings into dispatch regions.

In v2, we may want to do more propagation until it reaches the program boundary or some barriers. This requires analysis, IMO. Furthermore, we would iteratively apply steps 1 to 4 in this context. A rough idea is implementing the PropagationOpInterface for the dispatch region op, so that the dispatch region op can do local propagation and hoisting in its implementation.

v2 is just a rough idea I have for now; it needs more discussion. For now, let's focus on v1. @pashu123

cc @MaheshRavishankar @Max191
