[RFC] LazyScheduler for operator reordering


Under both eager mode and torch.compile, there are use cases that require reordering operators (in the compile case, across graphs) without model code changes, and existing APIs don’t serve them well.

For OSS models:

  • Some users with important use cases report the need to control when FSDP’s collectives and PP’s collectives happen (i.e. reordering them manually, without torch.compile).
  • For FSDP prefetching, we can potentially use a scheduler to decide when to schedule the prefetch, removing the need for CUDA streams and therefore significantly simplifying the implementation.

For internal models: see Meta internal use cases for LazyScheduler (Meta-internal link).

We’d like to have a flexible API to enable the above and all future cases.

Technical design

High-level ideas

  • We implement operator reordering via a “futures” approach: selected parts of the program have their execution start time decided by a scheduler. Concretely, we slice the Python program into “segments”, and decide when to run each segment via the scheduler.
  • It’s best if one segment = one single PyTorch nn module method.
    • In eager, each segment can contain multiple PyTorch ops and other non-PyTorch code.
    • In compile, each segment can contain multiple FX graphs and graph breaks.
    • This granularity is chosen because it has relatively low maintenance overhead for the user and is agnostic to graph breaks.

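The segment/future idea above can be sketched in plain Python (all names here are illustrative, not the actual LazyScheduler API): each segment becomes a handle whose body runs either when the scheduler decides to run it, or earlier, when a consumer forces its result.

```python
# Toy illustration of the "futures" idea. SegmentHandle is a hypothetical
# stand-in for the handle that wraps a segment's code and inputs.
class SegmentHandle:
    def __init__(self, fn, *args):
        self._fn = fn
        self._args = args
        self._result = None
        self._done = False

    def run(self):
        # Invoked by the scheduler at a time of its choosing.
        if not self._done:
            self._result = self._fn(*self._args)
            self._done = True
        return self._result

    def materialize(self):
        # Invoked by a consumer that needs the value *now*.
        return self.run()


# Original program order is [segment_a, segment_b]; the scheduler may
# instead run segment_b first and segment_a later.
segment_a = SegmentHandle(lambda x: x + 1, 10)
segment_b = SegmentHandle(lambda x: x * 2, 20)

segment_b.run()                    # scheduler chose to run B first
out_a = segment_a.materialize()    # A runs only when its output is demanded
assert (out_a, segment_b.run()) == (11, 40)
```
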
Potential problems from delaying execution

Let’s assume the original program execution order is [SegmentA, SegmentB], and we want to delay execution of SegmentA so that the actual execution order is [SegmentB, SegmentA]. (Hereinafter we call SegmentA the “DelayedSegment”.)

1) What if SegmentB needs output of DelayedSegment?

  • This can happen because SegmentB comes after SegmentA in the original program, so it can legitimately consume SegmentA’s output.
  • When this happens, we run DelayedSegment immediately (via materializing its AsyncTensor output), so SegmentB still receives the correct value.
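
A plain-Python analogy for how this forced materialization works (in the real design, AsyncTensor is a torch.Tensor subclass that intercepts ops via `__torch_dispatch__`; here ordinary operator overloading plays that role, and all names are hypothetical):

```python
# AsyncValue is a toy stand-in for AsyncTensor: it wraps the not-yet-run
# DelayedSegment, and any operation on it forces the segment to run.
class AsyncValue:
    def __init__(self, thunk):
        self._thunk = thunk     # the not-yet-run DelayedSegment
        self._value = None
        self._ran = False

    def materialize(self):
        if not self._ran:
            self._value = self._thunk()
            self._ran = True
        return self._value

    def __add__(self, other):   # any op touching the async value forces it
        return self.materialize() + other


ran = []
delayed_a = AsyncValue(lambda: ran.append("A") or 5)

# SegmentB consumes DelayedSegment's output, so DelayedSegment runs right
# here, and SegmentB still receives the correct value.
result_b = delayed_a + 3
assert result_b == 8 and ran == ["A"]
```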

2) How to handle reads or mutations to global variables (e.g. tensor or dict) within DelayedSegment?

  • This is problematic because there can be a data dependency between SegmentB and DelayedSegment through those global variables.
  • We detect this via Dynamo tracing DelayedSegment, and enforce that the segment does not read global variables and has no global mutation side-effects. (Note that we do this Dynamo tracing even in eager.)
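
The real enforcement happens while Dynamo traces DelayedSegment. As a much cruder stand-in, the sketch below flags global reads by inspecting a function’s bytecode with the standard `dis` module (the `reads_globals` helper is purely illustrative and far less precise than Dynamo’s analysis; real segments would also need an allowlist for benign module globals like `torch`):

```python
import dis

def reads_globals(fn, allowed=()):
    # Collect names the function loads from the global scope; an allowlist
    # covers benign globals such as imported modules.
    names = {
        ins.argval
        for ins in dis.get_instructions(fn)
        if ins.opname == "LOAD_GLOBAL"
    }
    return sorted(names - set(allowed))

counter = 0

def bad_segment(x):
    return x + counter      # hidden data dependency via a global: rejected

def good_segment(x):
    return x + 1            # no global reads: accepted

assert reads_globals(bad_segment) == ["counter"]
assert reads_globals(good_segment) == []
```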

3) How to handle mutations to input tensors?

  • Tensors mutated in SegmentB must never be read in DelayedSegment. Otherwise, the result is incorrect.
    • There is no “on-by-default” way to detect this without adding a lot of per-operator overhead in SegmentB (a tensor subclass adds too much overhead).
    • Instead, we detect it via “debug mode”: when debug_mode=True, we run SegmentA twice (with the same input) in this order: [SegmentA, SegmentB, SegmentA]. We assert that the output of each op in SegmentA stays the same. Note that this debug mode is expensive and off-by-default; it should only be used if the user observes an accuracy issue.
  • Tensors mutated in DelayedSegment must never be read in SegmentB. Otherwise, the result is incorrect.
    • Detected by checking via Dynamo that the mutated tensor is returned as part of DelayedSegment’s AsyncTensor output, marking that AsyncTensor as mutated, and then checking (via AsyncTensor subclass dispatch) that SegmentB doesn’t read it.
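
A minimal sketch of the debug-mode replay check described above, using plain Python values (for real tensors one would compare per-op outputs with torch.allclose; the helper name is hypothetical):

```python
# Run SegmentA, then SegmentB, then SegmentA again with the same input.
# If the two SegmentA runs disagree, SegmentB mutated state SegmentA reads.
def debug_check_segment_a(segment_a, segment_b, a_input, b_input):
    first = segment_a(a_input)
    segment_b(b_input)
    second = segment_a(a_input)
    assert first == second, "SegmentB mutated state that SegmentA reads"
    return first

shared = [1, 2, 3]            # stand-in for a shared buffer / global tensor

def segment_a(i):
    return shared[i]

def mutating_segment_b(_):
    shared[0] += 100          # hazardous in-place mutation

try:
    debug_check_segment_a(segment_a, mutating_segment_b, 0, None)
    caught = False
except AssertionError:
    caught = True
assert caught                 # the hazard was detected
```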

User-facing API

Suppose we want to run the SDD operation after OverArch’s func1 backward but right before its func2 backward, in order to overlap SDD with func2’s backward:

class SDDModule(nn.Module):
  def forward(self, x):
    return dist.all_to_all(x, …)

class OverArch(nn.Module):
  def func1(self, x):
    return torch.relu(x)

  def func2(self, x):
    return torch.matmul(x, x)

  def forward(self, x):
    x = self.func1(x)
    x = self.func2(x)
    return x

class Model(nn.Module):
  def __init__(self):
    super().__init__()
    self.sdd = SDDModule()
    self.overarch = OverArch()

  def forward(self, x):
    output = self.overarch(x)
    return output

model = Model()

# Create NN module method to segment mapping.
register_segment(model.sdd.forward, name="sdd_fwd")
register_segment(model.overarch.func2, is_backward=True, name="overarch_func2_bwd")

# Run "sdd_fwd" right before "overarch_func2_bwd".
register_segment_backward_pre_hook("overarch_func2_bwd", "sdd_fwd")

# Run the model as usual.
# LazyScheduler will be used under the hood to control the execution order.
output = model(inputs)


  • Once a segment is registered via register_segment(moduleA.method1, name=XYZ), the inputs and code of moduleA.method1 are packaged internally into an AsyncFuncHandle (via Module.__getattr__ in eager mode, and via method-name bookkeeping in torch.compile). moduleA’s __call__ method then invokes that AsyncFuncHandle, which in turn asks the Scheduler whether it should run now.
    • In the compile case, there will be a graph break between segments, and the scheduler can independently decide when to run each resulting graph.
  • Delayed segments always return AsyncTensors as output (which can be materialized on demand). Non-delayed segments always return real tensors.
  • Segment hook registration functions like register_segment_backward_pre_hook() are used to specify the relative execution order. Optionally we can provide an API for explicitly specifying the full schedule (e.g. LazyScheduler(schedule=["sdd_fwd", "overarch_func2_bwd"])).
  • The Scheduler looks at the schedule to decide whether to run the current segment. If the schedule contains other segments before the current one, the Scheduler tries to schedule them first and then schedules the current segment; if those preceding segments are not yet schedulable, the Scheduler stores the current segment and doesn’t schedule it.
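
A sketch of that decision rule under the assumptions above (all names are illustrative, not the actual API): delayed segments are parked when first encountered, and running a later segment first drains every parked predecessor in schedule order.

```python
class LazySchedulerSketch:
    def __init__(self, schedule):
        self.schedule = schedule   # e.g. ["sdd_fwd", "overarch_func2_bwd"]
        self.done = set()
        self.parked = {}           # delayed segments waiting to be scheduled

    def park(self, name, fn):
        # A delayed segment's handle is recorded but not run yet.
        self.parked[name] = fn

    def run(self, name, fn):
        # Before running `name`, run every parked segment that precedes it
        # in the schedule, in schedule order.
        for pred in self.schedule[: self.schedule.index(name)]:
            if pred not in self.done and pred in self.parked:
                self.parked.pop(pred)()
                self.done.add(pred)
        fn()
        self.done.add(name)


order = []
sched = LazySchedulerSketch(["sdd_fwd", "overarch_func2_bwd"])

# The forward pass reaches SDD: its segment is delayed rather than run.
sched.park("sdd_fwd", lambda: order.append("sdd"))
assert order == []

# The backward pass reaches func2's backward: the scheduler first runs the
# parked "sdd_fwd", then the current segment, giving the desired ordering.
sched.run("overarch_func2_bwd", lambda: order.append("func2_bwd"))
assert order == ["sdd", "func2_bwd"]
```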

Why can’t we do X instead of LazyScheduler?

We will go over a few known related approaches:

Approach 1: Implicit pipelining / StreamSyncTensor

  • Ying Liu’s StreamSyncTensor work (also internal link1 link2) is conceptually related to LazyScheduler, but it aims at implementing awaitable behavior (similar to AsyncCollectiveTensor) and doesn’t answer the question of “how do we specify when to execute an operation”, which is exactly what LazyScheduler is designed for.

Approach 2: Manually change the model code to implement whatever overlap optimization we want

  • This is always possible and has long been the standard approach for many models. But:
    • As different models need different overlap optimizations, it results in spaghetti code that’s hard to read and maintain.
      • LazyScheduler will help isolate cross-graph overlap related optimizations into a central place, expressed in a very succinct way.
    • It’s very easy to violate assumptions: e.g. the user expects A and B to overlap, but it turns out something the user does in A prevents the overlap.
      • With LazyScheduler, we can warn or throw an error when this kind of violation happens. It also lets the user split A into A1, A2, A3, … and precisely schedule B after A1 but before A2, an optimization that would otherwise be difficult or messy to implement.