Notes on lowering `arith.scaling_extf` and `arith.scaling_truncf` to AMDGPU · Issue #20821 · iree-org/iree · GitHub
1. Normalize to splat form:
   a. If the op is operating on scalars or 0-D vectors, go to `vector<1xT>`.
   b. Look at the scale input. If it isn't a broadcast or a splat (that is, if there isn't a single scalar value in all elements), unroll along leading or trailing dimensions (depending on what any `vector.broadcast` feeding the scale is doing) until you have ops of that form.
2. Now you've got `scaling_{ext,trunc}` ops of the form `%out = scaling_op %in, splat(%scale) : vector<...xT>, vector<...xU>, vector<...xV>`, where T, U, and V are the input, scale, and result element types.
3. Once we're in this form, the scale is a scalar from now on.
4. Then, determine the instruction-level block size of the intrinsic: 32, unless we'll be targeting the 2x16xf32 intrinsics, in which case use 16.
5. Start extracting [blocksize] elements at a time, padding with 0s if needed, and pass those to an `amdgpu.cvt_packed_scale`-style op.
   a. While there, the conversion from f8E8M0 to f32 gets to be a special case: unlike in the general case, we can just shift left and don't have to check for NaN.
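To make the data flow of steps 1 through 5 concrete, here is a minimal Python reference model (not MLIR, and not IREE's implementation). It assumes nested Python lists stand in for vectors, that input and scale have the same non-empty shape, that the scale has already been decoded to a plain float multiplier, and it only unrolls along leading dimensions (the trailing-dimension case would be symmetric). `flatten`, `unroll_to_splats`, `blocks`, `lower_scaling_op`, and `emulate_packed_op` are all hypothetical names; `emulate_packed_op` is a stand-in for whatever packed AMDGPU op the lowering actually emits.

```python
from typing import Callable, Dict, Iterator, List, Sequence, Tuple

# Hypothetical stand-in for the packed hardware op: (block, scalar scale) -> block.
PackedOp = Callable[[List[float], float], List[float]]


def flatten(x) -> list:
    """Flatten an arbitrarily nested list into a flat list of scalars."""
    if isinstance(x, list):
        return [y for sub in x for y in flatten(sub)]
    return [x]


def unroll_to_splats(inp, scale, index: Tuple[int, ...] = ()
                     ) -> Iterator[Tuple[Tuple[int, ...], list, float]]:
    """Steps 1-3: yield (index, flat_input, scalar_scale) pieces with a splat scale."""
    # Step 1a: treat a scalar / 0-D value as a 1-element vector.
    if not isinstance(inp, list):
        inp, scale = [inp], [scale]
    flat_scale = flatten(scale)
    if all(s == flat_scale[0] for s in flat_scale):
        # Steps 2-3: a single scalar scale now covers this whole piece.
        yield index, flatten(inp), flat_scale[0]
        return
    # Step 1b: not a splat yet, so unroll the leading dimension and retry.
    for i, (sub_in, sub_scale) in enumerate(zip(inp, scale)):
        yield from unroll_to_splats(sub_in, sub_scale, index + (i,))


def blocks(values: Sequence[float], block_size: int, pad: float = 0.0):
    """Step 5: cut values into block_size chunks, zero-padding the last one."""
    for start in range(0, len(values), block_size):
        chunk = list(values[start:start + block_size])
        chunk += [pad] * (block_size - len(chunk))
        yield start, chunk


def lower_scaling_op(inp, scale, packed_op: PackedOp,
                     use_2x16_intrinsic: bool = False) -> Dict[Tuple[int, ...], List[float]]:
    # Step 4: block size is 32, or 16 when targeting the 2x16xf32 intrinsics.
    block_size = 16 if use_2x16_intrinsic else 32
    results: Dict[Tuple[int, ...], List[float]] = {}
    for index, flat_in, s in unroll_to_splats(inp, scale):
        out: List[float] = []
        for start, chunk in blocks(flat_in, block_size):
            packed = packed_op(chunk, s)
            # Keep only the lanes that came from real input, not padding.
            out.extend(packed[: len(flat_in) - start])
        results[index] = out
    return results


def emulate_packed_op(chunk: List[float], s: float) -> List[float]:
    # Hypothetical emulation: multiply each lane by the (already decoded) scale.
    return [x * s for x in chunk]


# prints {(0,): [2.0, 4.0], (1,): [1.5, 2.0]}
print(lower_scaling_op([[1.0, 2.0], [3.0, 4.0]], [[2.0, 2.0], [0.5, 0.5]], emulate_packed_op))
```

Because the scale rows differ, the 2x2 example gets unrolled into two splat-scaled rows; a fully uniform scale would stay as one piece. In an actual lowering each yielded piece would correspond to an extracted subvector and each chunk to one packed instruction; this model only tracks which values end up paired with which scalar scale.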
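For step 5a, a minimal Python sketch of the shift-only f8E8M0 to f32 path. E8M0 is a bare 8-bit exponent with bias 127, the same bias as the f32 exponent field, so widening is just moving the byte into bits 23..30 and bitcasting. This models the bit manipulation only, not the AMDGPU instruction sequence; the function names are hypothetical. One thing worth keeping in mind: the plain shift sends encoding 0x00 (2^-127, which is a denormal in f32) to 0.0 and the all-ones encoding 0xFF to +inf, so those two endpoints are where any extra handling would go if the target's semantics require it.

```python
import struct


def e8m0_to_f32_bits(byte: int) -> int:
    """Widen an E8M0 scale (bias 127) to f32 bits by moving it into the exponent field."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 values are 8 bits")
    # f32 layout: 1 sign bit, 8 exponent bits (bias 127), 23 mantissa bits.
    # Sign and mantissa stay zero, so for byte in 1..254 the result is
    # exactly 2**(byte - 127).
    return byte << 23


def e8m0_to_f32(byte: int) -> float:
    """Bitcast the 32-bit pattern to a float via IEEE-754 single precision."""
    return struct.unpack("<f", struct.pack("<I", e8m0_to_f32_bits(byte)))[0]


for b in (120, 127, 128, 0, 255):
    print(hex(b), e8m0_to_f32(b))
```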
On further review, it looks like there are also scalar (or quasi-scalar) round-down and round-up instructions that got added while I wasn't looking, so a lot of this gets easier.
The notes above are from me sketching this out in Umang's DMs.