Commit

Schedule loop domains such that reshape transforms are cancelled

naoyam committed Jan 8, 2025
1 parent 9ce2112 commit c85ea88

Showing 3 changed files with 509 additions and 14 deletions.
195 changes: 190 additions & 5 deletions csrc/scheduler/tools/loop_domain_scheduler.cpp
@@ -10,6 +10,7 @@
#include <id_model/id_model.h>
#include <id_model/schedule.h>
#include <ir/internal_nodes.h>
#include <ir/utils.h>
#include <scheduler/tools/loop_domain_scheduler.h>
#include <val_graph_visitor.h>

@@ -323,6 +324,26 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const {
//
// In the case of the update mode, the target should be just the
// current loop domain of the tensor.
//
// TODO: Reconsider if ignoring broadcast IDs is the right thing to
// do. There can be legitimate broadcast transformations in the
// reference, and if there's a matching broadcast ID in this tv, the
// transformations should be propagated. However, when a concrete ID
// of a reference tv replaces a broadcast ID of this tv, there's no
// path from the concrete ID to the broadcast ID, thus getting a
// path would fail. I think the fundamental problem is the
// disconnection from the broadcast ID. This could be avoided if the
// Broadcast graph was used. However, for scheduling loop domains,
// the Broadcast graph isn't the right choice when a corresponding
// concrete ID is resized. For example, if a concrete
// ID is resized to a broadcast ID, how should the corresponding
// broadcast ID be resized? If the same resize were applied, the
// extent of the resized broadcast ID would become negative, which
// is probably not what we would want. (Concretely, shrinking a
// concrete ID of extent 8 down to 1 means a resize with expansions
// of -3 and -4; applying the same resize to a broadcast ID of extent
// 1 would give 1 - 3 - 4 = -6.) Perhaps we should consider
// expanding the broadcast ID to the concrete size and only map
// expanded broadcast IDs with concrete IDs. And if the expand is
// represented with an IterDomain expression, we could avoid
// disconnected IDs.
ValGroups tv_target_domains = graph().toGroups(TensorDomain::noBroadcasts(
update_loop_domain_only_ ? tv->getLoopDomain()
: tv->getMaybeRootDomain()));
Expand All @@ -345,6 +366,17 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const {
.first;
}

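// Debug output: print this tensor, all ancestors of the reference
// loop domain, and any target ID groups not found among those
// ancestors.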
if (update_loop_domain_only_) {
std::cerr << "TV: " << tv->toString() << "\n";
std::cerr << "All ancestors: " << nvfuser::toString(all_ancestors_of_ref_)
<< "\n";
for (const auto& tv_target_id : tv_target_domains) {
if (!all_ancestors_of_ref_.has(tv_target_id)) {
std::cerr << "Not found: " << nvfuser::toString(tv_target_id) << "\n";
}
}
}

// In the case of the update mode, the path from the reference is
// assumed to be just a backward traversal path.
NVF_ERROR(
@@ -407,7 +439,8 @@ void scheduleLoopDomainsLike(

void scheduleLoopDomainsBy(
const std::vector<TensorView*>& tvs,
Expr* transform) {
Expr* transform,
Direction replay_dir) {
Fusion* fusion = transform->fusion();
IdModel id_model(fusion, /*build_graphs=*/false);
const ValGraph& exact_graph = id_model.buildExactGraph();
@@ -439,15 +472,16 @@ void scheduleLoopDomainsBy(
}
}

Direction replay_dir = Direction::Undefined;

// It should be either: all of the inputs found and none of the
// outputs found, or none of the inputs found and all of the
// outputs found.
if (input_ids.size() == transform->inputs().size()) {
if (replay_dir != Direction::Backward &&
input_ids.size() == transform->inputs().size()) {
NVF_ERROR(output_ids.empty());
replay_dir = Direction::Forward;
} else if (output_ids.size() == transform->outputs().size()) {
} else if (
replay_dir != Direction::Forward &&
output_ids.size() == transform->outputs().size()) {
NVF_ERROR(input_ids.empty());
replay_dir = Direction::Backward;
} else {
@@ -500,5 +534,156 @@ void scheduleLoopDomainsBy(
return;
}

void cancelReshapeInLoopDomains(TensorView* from_tv) {
Fusion* fusion = from_tv->fusion();
IdModel id_model(fusion, /*build_graphs=*/false);
id_model.buildExactGraph();
const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT);

// Reshapes producing these IDs should not be cancelled
ValGroups reshape_dependent_ids;
for (const ExprGroup& expr_g :
exact_graph.disjointExprSets().disjointSets()) {
if (expr_g->front()->isA<Resize>()) {
reshape_dependent_ids.pushBack(exact_graph.inputGroups(expr_g));
}
}

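// Similarly, reshape output IDs that end up being reduced should not be
// cancelled, so treat all reduction ID groups as reshape dependent as
// well (see the comment on cancelReshapeInLoopDomains in the header).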
for (const ValGroup& val_g : exact_graph.disjointValSets().disjointSets()) {
if (std::any_of(val_g->begin(), val_g->end(), [](Val* val) {
NVF_ERROR(val->isA<IterDomain>());
return val->as<IterDomain>()->isReduction();
})) {
reshape_dependent_ids.pushBack(val_g);
}
}

auto all_dep_exprs_from_tv =
DependencyCheck::getAllExprsBetween({from_tv}, fusion->outputs());

// Visit all reshapes in a reverse topological order
for (auto exprs_it = all_dep_exprs_from_tv.rbegin();
exprs_it != all_dep_exprs_from_tv.rend();
++exprs_it) {
auto reshape = dynamic_cast<ViewOp*>(*exprs_it);
if (reshape == nullptr) {
continue;
}

auto reshape_out = reshape->out();

auto all_dep_vals =
DependencyCheck::getAllValsBetween({reshape_out}, fusion->outputs());
// Exclude reshape_out
all_dep_vals.erase(all_dep_vals.begin());
auto all_dep_tvs = ir_utils::filterByType<TensorView>(all_dep_vals);

// Find logical IDs that do not exist in the root domain. They are
// the new IDs that are produced by this reshape op. If a logical
// ID is already found in the root domain, there's nothing to do
// for it.
std::vector<IterDomain*> new_logical_ids;
for (const auto& logical_id : reshape_out->getLogicalDomain()) {
if (!reshape_out->domain()->isRoot(logical_id)) {
new_logical_ids.push_back(logical_id);
}
}

if (new_logical_ids.empty()) {
// Nothing to do for a no-op reshape, which may never occur in practice.
continue;
}

// Find logical IDs that do not need to exist in the loop domain
std::unordered_set<Val*> cancellable_ids;
for (const auto new_logical_id : new_logical_ids) {
auto new_id_group = exact_graph.toGroup(new_logical_id);
// Not cancellable if used by resize or reduced.
auto reachable_exprs = getReachableNodesFrom<ValGraphPermissiveBFS>(
{new_id_group},
{reshape_dependent_ids.begin(), reshape_dependent_ids.end()},
Direction::Forward,
exact_graph);
if (!reachable_exprs.empty()) {
continue;
}

cancellable_ids.insert(new_logical_id);
}

if (cancellable_ids.empty()) {
continue;
}

// Update the loop domain by each of the reshape exprs in a
// reverse topological order.
auto reshape_exprs = DependencyCheck::getAllExprsBetween(
{reshape_out->getRootDomain().begin(),
reshape_out->getRootDomain().end()},
{reshape_out->getLogicalDomain().begin(),
reshape_out->getLogicalDomain().end()});

auto reshape_out_loop_domain = reshape_out->getLoopDomain();

for (auto reshape_exprs_it = reshape_exprs.rbegin();
reshape_exprs_it != reshape_exprs.rend();
++reshape_exprs_it) {
auto reshape_expr = *reshape_exprs_it;

// If any of the output IDs of reshape_expr is not found in
// cancellable_ids, that means the expr cannot be cancelled.
if (std::any_of(
reshape_expr->outputs().begin(),
reshape_expr->outputs().end(),
[&](Val* reshape_expr_out) -> bool {
return !cancellable_ids.count(reshape_expr_out);
})) {
continue;
}

// Update all of the dependent TVs by this reshape expr
scheduleLoopDomainsBy(
all_dep_tvs.vector(), reshape_expr, Direction::Backward);

cancellable_ids.insert(
reshape_expr->inputs().begin(), reshape_expr->inputs().end());

// For the reshape output tensor itself, since it already has the
// reshape expr, it just needs
// tv->setLoopDomain(tv->getRootDomain()). However, since some of the
// reshape exprs may not be cancellable, update a vector of the
// loop IDs for each of the cancelled exprs individually and use
// it to set the loop domain of the reshape output tensor.

// Insert the input IDs to the loop domain
auto insert_pos = std::find(
reshape_out_loop_domain.begin(),
reshape_out_loop_domain.end(),
reshape_expr->outputs().front());
NVF_ERROR(insert_pos != reshape_out_loop_domain.end());
for (auto inp : reshape_expr->inputs()) {
insert_pos =
reshape_out_loop_domain.insert(insert_pos, inp->as<IterDomain>());
++insert_pos;
}

// Remove the output IDs
reshape_out_loop_domain.erase(
std::remove_if(
reshape_out_loop_domain.begin(),
reshape_out_loop_domain.end(),
[&](IterDomain* cur_loop_id) {
return std::find(
reshape_expr->outputs().begin(),
reshape_expr->outputs().end(),
cur_loop_id) != reshape_expr->outputs().end();
}),
reshape_out_loop_domain.end());
}

reshape_out->setLoopDomain(reshape_out_loop_domain);
}
}

} // namespace scheduler_tools
} // namespace nvfuser
63 changes: 54 additions & 9 deletions csrc/scheduler/tools/loop_domain_scheduler.h
@@ -7,13 +7,17 @@
// clang-format on
#pragma once

#include <bfs.h>

#include <vector>

namespace nvfuser {

class Expr;
class Fusion;
class TensorView;
class IterDomain;
class ViewOp;

namespace scheduler_tools {

@@ -30,14 +34,14 @@ void scheduleLoopDomainsLike(
bool update_loop_domain_only = false);

// Replay a transform expr on the loop domain of each of the given
// tensors. If the input of the transform is exact mapped with the loop
// domain, the transform is replayed as a forward op. If the output
// is exact mapped with the loop domain, it's replayed as a backward
// op. The loop domain of each tensor is updated with the replayed
// transform expr. If it's replayed as a forward op, the outputs
// replace the inputs in the loop domain. If it's replayed as a
// backward op, the inputs replace the outputs in the loop domain. The
// new IDs are inserted at the outermost position of the input IDs.
// tensors. If the replay direction is specified, the expr is replayed
// as specified. Otherwise, if the input of the transform is exact mapped with
// the loop domain, the transform is replayed as a forward op. If the output is
// exact mapped with the loop domain, it's replayed as a backward op. The loop
// domain of each tensor is updated with the replayed transform expr. If it's
// replayed as a forward op, the outputs replace the inputs in the loop domain.
// If it's replayed as a backward op, the inputs replace the outputs in the loop
// domain. The new IDs are inserted at the outermost position of the input IDs.
//
// For example, suppose a fusion has:
//
@@ -62,7 +66,48 @@
// LoopDomainSchedulingTest.ScheduleLoopDomainsBy1 for more examples.
void scheduleLoopDomainsBy(
const std::vector<TensorView*>& tvs,
Expr* transform);
Expr* transform,
Direction replay_dir = Direction::Undefined);
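
Illustrative usage of the new replay_dir parameter (an editor's sketch, not part of this diff; it assumes the nvFuser C++ test helpers such as makeSymbolicTensor and FusionGuard):

Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = set(tv0);
auto tv2 = set(tv1);
fusion.addOutput(tv2);

// Transform tv2 and grab the Split expr that defined its new loop IDs.
tv2->split(0, 4);
Expr* split = tv2->axis(0)->definition();

// With the default Direction::Undefined, the direction is inferred: the
// split input is exact-mapped with the loop domains of tv0 and tv1, so
// the split is replayed forward on both tensors.
scheduleLoopDomainsBy({tv0, tv1}, split);

// An explicit direction pins the choice instead of inferring it, e.g.,
// Direction::Backward as cancelReshapeInLoopDomains does when replaying
// reshape exprs back onto consumer tensors.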

// For each of the immediate and indirect consumer tensors of from_tv,
// schedule its loop domain such that reshape transforms appearing
// between the tensor and from_tv are cancelled. For example, suppose
// a fusion has:
//
// t0 = makeSymbolicTensor(3); // [i0, i1, i2]
// t1 = permute(t0, {1, 0, 2}); // [i1, i0, i2]
// t2 = reshape(t1, {i1, i0*i2}); // [i1, i0*i2]
// t3 = sin(t2) // [i1, i0*i2]
//
// In this case, cancelReshapeInLoopDomains(t0) would affect t2 and t3
// as follows:
//
// t2:
// root: [i1, i0*i2] (unchanged)
// logical: [i1, i0*i2] (unchanged)
// loop: [i1, i0, i2]
//
// t3:
// logical: [i1, i0*i2] (unchanged)
// loop: [i1, i0, i2]
//
// t1 would not be changed at all as there's no reshape between t0 and
// t1.
//
// This scheduling could help optimize memory accesses to
// fusion inputs. In the above case, we could then reorder the loop
// domains of t1, t2 and t3 as [i0, i1, i2], i.e., the same ordering
// as t0, which could minimize strided accesses.
//
// This scheduling is not always feasible. Specifically, if a reshape
// output iter domain is resized, the loop domain needs to keep using
// the reshape output iter domain. Similarly, if a reshape output iter
// domain is reduced, the reshape is currently not cancelled. This is
// because if a reshape has a split and only one of the split output
// iter domains is reduced, the split needs to remain. If a reshape
// only consists of merge transforms, cancellation should be possible,
// but that is not currently supported.
void cancelReshapeInLoopDomains(TensorView* from_tv);
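
A minimal end-to-end sketch of the example above (an editor's illustration, not part of this diff; it assumes the test helpers makeConcreteTensor and FusionGuard and the static reshape overload that takes original and new sizes):

Fusion fusion;
FusionGuard fg(&fusion);

// t0: [i0, i1, i2] with concrete extents {4, 8, 16}
auto t0 = makeConcreteTensor({4, 8, 16});
fusion.addInput(t0);
auto t1 = permute(t0, {1, 0, 2});            // [i1, i0, i2]
auto t2 = reshape(t1, {8, 4, 16}, {8, 64});  // [i1, i0*i2]
auto t3 = sin(t2);                           // [i1, i0*i2]
fusion.addOutput(t3);

// Cancel the reshape in the loop domains of t2 and t3. Their logical
// domains stay [i1, i0*i2], but their loop domains become [i1, i0, i2].
scheduler_tools::cancelReshapeInLoopDomains(t0);

// The loop domains of t2 and t3 can now be reordered to follow t0's
// ordering, [i0, i1, i2], which could minimize strided accesses to t0.
for (auto tv : {t2, t3}) {
  tv->reorder({{0, 1}, {1, 0}});             // [i1, i0, i2] -> [i0, i1, i2]
}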

} // namespace scheduler_tools
} // namespace nvfuser