-
Notifications
You must be signed in to change notification settings - Fork 113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
tablegen: Add StaticSelect to select based on static condition #2206
Conversation
string value = val; | ||
} | ||
|
||
class StaticIf<bit uses_primal, bit uses_shadow, string condition_> : Operation<uses_primal, uses_shadow> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think probably we should remove the explicit needs primal and shadow, and instead directly integrate this into the use analysis (lookup SelectIfActive in this case we wouldn’t do the check on if active but check on whatever the string is)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking into it, I could not find the custom use analysis for SelectIfActive, it seems that instead SelectIfActive has no analysis:
SmallVector<bool> cachedArguments(Operation *op,
MGradientUtilsReverse *gutils) const {
SmallVector<bool> toret(op->getNumOperands(), false);
for (size_t idx=0; idx<op->getNumOperands(); idx++) {
bool used = false;
// Rule (Op ?:$x, ?:$y)
// Arg 0 : (SelectIfActive ?:$x, (CMul (DiffeRet), ?:$y), ?:$x)
// Arg 1 : (DiffeRet)
toret[idx] = used;
}
return toret;
}
And it currently works because either it is used in the Forward diff which seem to rely on the analysis or the primals are used outside of the SelectIfActive node.
It seems that the assert here does not trigger due to the way enzyme-tblgen might be built?
assert(!usesCustom); |
So really, it should base its use on its arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yeah we should add support there (actually for both static if and if active). The reason we never hit that branch is because use analysis is auto generated from reverse mode rules and I think we only use select if active in forward.
However here we really want to add the correct diff use analysis, since then we can for example use static if to define rules for reduce and static if it’s a known max min or add etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But also it should do so recursively for the operands of the select if active (depending on if active)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, the same analysis need to be used for SelectIfComplex as well (or it should be replaced with an instance of StaticIf)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean ideally we can define selectifcomplex in tablegen with selectif
I am trying to implement SelectIfActive based on StaticSelect as well but it looks like it has special handling for vectorization as well: Enzyme/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp Lines 484 to 498 in 5b330a9
I guess I should move this to StaticSelect as well |
ah yeah that one might be hard. Yeah the return of handle is whether or not the value is batched already (e.g. for batched vector mode). The right thing to do is probably as follows. Emit both the lhs and the rhs. If they have the same vector mode set (either both not vectored or both vectored), that's fine and return it. If one has it set, upgrade the other one to vector mode (e.g. do that insertvalue related stuff), and return vector = true. |
cc @jumerckx this will probably also need an MLIR version too at one point (which presumably calls the new broadcast op) |
(void)usesCustom; | ||
assert(!usesCustom); | ||
// This only concerns instances of StaticSelect for now | ||
if (usesCustom) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very minor but can we explicitly check opName == "StaticSelect" || Def->isSubClassOf("StaticSelect") and leave this assertion as is?
(void)usesCustom; | ||
assert(!usesCustom); | ||
|
||
if (Def->isSubClassOf("StaticSelect")) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these should be in the other order the assertion come after the handler, and the handler should also return to ensure the error isn't hit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah right, I moved the assert after and set StaticSelect to have useCustom
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realized a bug in the useCustom analysis logic that was easier to just code up so I went ahead and fixed it
cc @mofeing, we can collaborate to integrate in your fft pr.
I am open to other naming ideas.