From 929ed00c6b17de3bb9fa0146af022e742bf12ecd Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 26 Nov 2025 05:07:10 -0800 Subject: [PATCH 01/15] Apply swap_map to bsym before any reasoning --- thunder/core/update_aliases.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index de8ec89604..0b2b46b05e 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -167,6 +167,7 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li # Third pass: insert alias updates for bsym in computation_trace.bound_symbols: if _is_inplace_op(bsym) or _is_view_creation_op(bsym) or _involves_viewed_args(bsym, viewed): + bsym = bsym.from_bsym_swap_proxies(swap_map, skip_output=True) in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))) if _is_inplace_op(bsym) and in_tensors: in_tensors = {in_tensors[0]} @@ -177,7 +178,7 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li group = set(reduce(set.union, filter(lambda g: any(g.intersection(in_tensors)), view_groups), set())) if not group or not (views_encountered := group.intersection(encountered)): # If group is empty, this is a view creation with operands that are not involved in any inplace ops. - bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True)) + bsyms.append(bsym) continue new_aliases = _get_new_aliases(views_encountered, computation_trace) From 86adceb8eeca50b153652fae8dd93145a6be5f7c Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 26 Nov 2025 08:24:48 -0800 Subject: [PATCH 02/15] Keep fusion break before in-place --- thunder/core/update_aliases.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 0b2b46b05e..74fef80400 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -176,8 +176,15 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs))) encountered.update(in_tensors) group = set(reduce(set.union, filter(lambda g: any(g.intersection(in_tensors)), view_groups), set())) - if not group or not (views_encountered := group.intersection(encountered)): - # If group is empty, this is a view creation with operands that are not involved in any inplace ops. + views_encountered = group.intersection(encountered) + + if _is_inplace_op(bsym): + # Super-hacky workaround to insert fusion break because nvFuser doesn't support mutation on intermediates + # See https://github.com/Lightning-AI/lightning-thunder/issues/2768#issuecomment-3581908434 + views_encountered = in_tensors + + if not views_encountered: + # This is a view creation with operands that are not involved in any inplace ops. bsyms.append(bsym) continue From 92663e5348b184d4966f1af0bf4ded232e394158 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 26 Nov 2025 08:58:10 -0800 Subject: [PATCH 03/15] Add test --- thunder/tests/test_update_aliases.py | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 4d293aae21..7cafcbec0f 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -542,3 +542,33 @@ def f(x, y, z): torch.testing.assert_close(a, a_) torch.testing.assert_close(b, b_) torch.testing.assert_close(c, c_) + + +@instantiate( + dtypes=(dtypes.float32,), +) +def test_update_aliases_count(executor, device, dtype): + def f(x): + x.sin_() + return x * x * x * x + + def g(x): + x.sin_() + x.cos_() + return x * x * x * x + + expected_num_update_aliases = { + f: 1, # before sin_ + g: 2, # before sin_ and cos_; latter is a hack to cause fusion break + } + + for fn in [f, g]: + a = make_tensor((2, 3), dtype=dtypes.to_torch_dtype(dtype), device=device) + a_ = a.clone().detach() + jfn = executor.make_callable(fn) + actual = jfn(a) + expected = fn(a_) + torch.testing.assert_close(actual, expected) + extrace = thunder.last_traces(jfn)[-1] + actual_num_update_aliases = len([bsym for bsym in extrace.bound_symbols if bsym.sym.name == "update_aliases"]) + assert actual_num_update_aliases == expected_num_update_aliases[fn] From c5d914d1e7301bd9c0f1532beae734d36d7630f9 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 26 Nov 2025 09:57:35 -0800 Subject: [PATCH 04/15] Cosmetic change --- thunder/core/update_aliases.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 74fef80400..49dd81b3cd 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -179,9 +179,8 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li views_encountered = group.intersection(encountered) if _is_inplace_op(bsym): - # Super-hacky workaround to insert fusion break because nvFuser doesn't support mutation on intermediates - # See https://github.com/Lightning-AI/lightning-thunder/issues/2768#issuecomment-3581908434 - views_encountered = in_tensors + # This is a hack to insert fusion break because nvFuser doesn't support mutation on intermediates + views_encountered.update(in_tensors) if not views_encountered: # This is a view creation with operands that are not involved in any inplace ops. From 3954b513a51821bed91f22c4a346473d2c54614d Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 5 Dec 2025 09:12:07 -0800 Subject: [PATCH 05/15] Remove no longer needed xfail --- thunder/tests/test_update_aliases.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 7cafcbec0f..3e8de8b344 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -19,7 +19,6 @@ NOTHING, TorchExecutor, TorchCompileExecutor, - nvFuserExecutor, requiresCUDA, ) from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place @@ -354,9 +353,6 @@ def f(x, y, z): decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),), ) def test_write_to_intermediate_result(executor, device, dtype, cache): - if executor == nvFuserExecutor: - pytest.xfail("nvFuser does not support writing to intermediate results") - def fn(x): y = x.view(-1) y.add_(1) From d765875500e775f50511495edc883dbcbcd880dd Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 5 Dec 2025 09:15:51 -0800 Subject: [PATCH 06/15] Add regressed case as xfailed test --- thunder/tests/test_update_aliases.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 3e8de8b344..83d0f899b3 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -365,6 +365,29 @@ def fn(x): torch.testing.assert_close(actual, expected) +@instantiate( + dtypes=NOTHING, + decorators=( + pytest.mark.xfail( + reason="Writing to viewed intermediate. See https://github.com/Lightning-AI/lightning-thunder/issues/2766" + ), + pytest.mark.parametrize("requires_grad", (False, True)), + ), +) +def test_write_to_viewed_intermediate(executor, device, dtype, requires_grad): + def fn(a): + b = a * 2 + c = b[:] + c.tanh_() + return a * b + + a = make_tensor((2, 3), dtype=torch.float32, device=device, requires_grad=requires_grad) + jfn = executor.make_callable(fn, fusion_type="dataflow") + actual = jfn(a) + expected = fn(a) + torch.testing.assert_close(actual, expected) + + @instantiate( dtypes=(dtypes.float32,), ) From cc595a12c66a1570f35776f504e460a8750d78c0 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 5 Dec 2025 10:24:06 -0800 Subject: [PATCH 07/15] Add output of update_aliases to encountered --- thunder/core/update_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 49dd81b3cd..d9ac527ded 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -194,7 +194,7 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li if has_tags(bsym, {BoundSymbolTag.BACKWARD}): update_bsym.tags.add(BoundSymbolTag.BACKWARD) bsyms.append(update_bsym) - encountered.update(out_tensors) + encountered.update(out_tensors, map(variableify, new_aliases)) bsyms.append(new_bsym) if _is_inplace_op(bsym) and len(out_tensors) == 1 and len(in_tensors) == 1: # This relies on these being one element sets (ltorch.setitem_ yields no outs). From a0fbd456a351b8dc5973f5af24819c456bae651c Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 5 Dec 2025 10:25:03 -0800 Subject: [PATCH 08/15] Fix xfail: Apply variable renaming to view_groups --- thunder/core/update_aliases.py | 19 +++++++++++++++---- thunder/tests/test_update_aliases.py | 7 +------ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index d9ac527ded..be498a0f1a 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -166,8 +166,8 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li # Third pass: insert alias updates for bsym in computation_trace.bound_symbols: + bsym = bsym.from_bsym_swap_proxies(swap_map) if _is_inplace_op(bsym) or _is_view_creation_op(bsym) or _involves_viewed_args(bsym, viewed): - bsym = bsym.from_bsym_swap_proxies(swap_map, skip_output=True) in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))) if _is_inplace_op(bsym) and in_tensors: in_tensors = {in_tensors[0]} @@ -175,8 +175,9 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li in_tensors = set(in_tensors) out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs))) encountered.update(in_tensors) - group = set(reduce(set.union, filter(lambda g: any(g.intersection(in_tensors)), view_groups), set())) - views_encountered = group.intersection(encountered) + involved_view_groups = [g for g in view_groups if g.intersection(in_tensors)] + involved_views = set().union(*involved_view_groups) + views_encountered = involved_views.intersection(encountered) if _is_inplace_op(bsym): # This is a hack to insert fusion break because nvFuser doesn't support mutation on intermediates @@ -200,8 +201,18 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li # This relies on these being one element sets (ltorch.setitem_ yields no outs). swap_map = _update_swap_map(swap_map, in_tensors.pop(), unvariableify(out_tensors.pop())) + # views_encountered and new_aliases refer to the same variables in the original trace, + # so we update view groups to use the latest variables in the new trace + variable_renames = { + alias: variableify(new_alias) for alias, new_alias in zip(views_encountered, new_aliases) + } + for i, group in enumerate(involved_view_groups): + new_group = {variable_renames.get(t, t) for t in group} + view_groups[i] = new_group + viewed = set().union(*view_groups) + else: - bsyms.append(bsym.from_bsym_swap_proxies(swap_map)) + bsyms.append(bsym) alias_updated_trace = from_trace(computation_trace) alias_updated_trace.set_provenance(TraceProvenance("Update aliases for in-place ops")) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 83d0f899b3..1eb5a2d16e 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -367,12 +367,7 @@ def fn(x): @instantiate( dtypes=NOTHING, - decorators=( - pytest.mark.xfail( - reason="Writing to viewed intermediate. See https://github.com/Lightning-AI/lightning-thunder/issues/2766" - ), - pytest.mark.parametrize("requires_grad", (False, True)), - ), + decorators=(pytest.mark.parametrize("requires_grad", (False, True)),), ) def test_write_to_viewed_intermediate(executor, device, dtype, requires_grad): def fn(a): From 2f6609ce2a01f8c0510805400c6e66ddc149054e Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 5 Dec 2025 10:32:38 -0800 Subject: [PATCH 09/15] Don't depend on consistency of set iteration order --- thunder/core/update_aliases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index be498a0f1a..20d6215449 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -177,11 +177,11 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li encountered.update(in_tensors) involved_view_groups = [g for g in view_groups if g.intersection(in_tensors)] involved_views = set().union(*involved_view_groups) - views_encountered = involved_views.intersection(encountered) + views_encountered = tuple(involved_views.intersection(encountered)) if _is_inplace_op(bsym): # This is a hack to insert fusion break because nvFuser doesn't support mutation on intermediates - views_encountered.update(in_tensors) + views_encountered = tuple(in_tensors.union(views_encountered)) if not views_encountered: # This is a view creation with operands that are not involved in any inplace ops. From 4cee1ad14d0557c67bcc0034a85adb41caf3e31e Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 11 Dec 2025 21:01:05 -0800 Subject: [PATCH 10/15] Unswap instead of updating view groups --- thunder/core/update_aliases.py | 20 +++++++------------- thunder/tests/test_update_aliases.py | 6 ++++++ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 9ed54b1f77..4734565ef4 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -144,9 +144,6 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols): return computation_trace - swap_map = dict() - bsyms = [] - # First pass: identify inputs which are views of each other and swap them out with a default, # reshaping if necessary. computation_trace, view_groups = replace_args_with_alias_map(computation_trace, alias_tensor_indices) @@ -173,11 +170,15 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0] viewed = set(reduce(set.union, view_groups, set())) + swap_map = dict() + swap_map_by_update_aliases = dict() + bsyms = [] + # Third pass: insert alias updates for bsym in computation_trace.bound_symbols: bsym = bsym.from_bsym_swap_proxies(swap_map) in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))) - unswapped_in_tensors = _unswap(swap_map, in_tensors) + unswapped_in_tensors = _unswap(swap_map_by_update_aliases, in_tensors) if ( _is_inplace_op(bsym) or _is_view_creation_op(bsym) @@ -216,15 +217,8 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li # This relies on these being one element sets (ltorch.setitem_ yields no outs). swap_map = _update_swap_map(swap_map, in_tensors.pop(), unvariableify(out_tensors.pop())) - # views_encountered and new_aliases refer to the same variables in the original trace, - # so we update view groups to use the latest variables in the new trace - variable_renames = { - alias: variableify(new_alias) for alias, new_alias in zip(views_encountered, new_aliases) - } - for i, group in enumerate(involved_view_groups): - new_group = {variable_renames.get(t, t) for t in group} - view_groups[i] = new_group - viewed = set().union(*view_groups) + for alias, new_alias in zip(views_encountered, new_aliases): + _update_swap_map(swap_map_by_update_aliases, alias, new_alias) else: bsyms.append(bsym) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index f4d7793d66..6febea4a5b 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -575,9 +575,15 @@ def g(x): x.cos_() return x * x * x * x + def h(x): + y = x[:] + y.sin_() + return x * x * x * x + expected_num_update_aliases = { f: 1, # before sin_ g: 2, # before sin_ and cos_; latter is a hack to cause fusion break + h: 5, # before sin_ and every mul } for fn in [f, g]: From 61a125ed2d2f20179d00f2d779481a612d5570ba Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 11 Dec 2025 23:22:45 -0800 Subject: [PATCH 11/15] Workaround to always put return bsym at last --- thunder/executors/nvfuserex_impl.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 8c2d0f6324..6a0dbb007e 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -817,8 +817,13 @@ def map_redundant(x: Any) -> Any: new_symbols = [new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols] cse_trace.bound_symbols = list(filterfalse(lambda a: a is None, new_symbols)) - return_bsym = cse_trace.bound_symbols[-1] - assert return_bsym.sym.id == prims.PrimIDs.RETURN + return_bsym = None + for idx, bsym in enumerate(cse_trace.bound_symbols): + if bsym.sym.id == prims.PrimIDs.RETURN: + return_bsym = cse_trace.bound_symbols.pop(idx) + break + assert return_bsym is not None + trace_output = tree_map(map_redundant, return_bsym.args) cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=None) From 912adbd9c67c9d06c0aecd9269db625084813557 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 11 Dec 2025 23:24:11 -0800 Subject: [PATCH 12/15] fixup --- thunder/executors/nvfuserex_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 6a0dbb007e..d53deb27ae 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -825,7 +825,7 @@ def map_redundant(x: Any) -> Any: assert return_bsym is not None trace_output = tree_map(map_redundant, return_bsym.args) - cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=None) + cse_trace.bound_symbols.append(prims.python_return.bind(*trace_output, output=None)) end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns From ed75fe1872a9e839b06a4e4f2e55296082d07517 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 11 Dec 2025 23:31:08 -0800 Subject: [PATCH 13/15] Add comment --- thunder/core/update_aliases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 4734565ef4..448bb4e9ae 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -178,6 +178,8 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li for bsym in computation_trace.bound_symbols: bsym = bsym.from_bsym_swap_proxies(swap_map) in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))) + # We do not unswap out_tensor of an inplace bsym into in_tensor, because functional dependency is already + # captured by that reference to out_tensor unswapped_in_tensors = _unswap(swap_map_by_update_aliases, in_tensors) if ( _is_inplace_op(bsym) From 3627ee6ef4f4b7a825933439e2e86cbe5a168157 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 12 Dec 2025 05:03:28 -0800 Subject: [PATCH 14/15] Add comment --- thunder/executors/nvfuserex_impl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index d53deb27ae..8930c1f082 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -817,6 +817,8 @@ def map_redundant(x: Any) -> Any: new_symbols = [new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols] cse_trace.bound_symbols = list(filterfalse(lambda a: a is None, new_symbols)) + # TODO: Remove this and assert that return_bsym is at the end of the trace + # This is a temporary workaround until https://github.com/Lightning-AI/lightning-thunder/issues/2776 is fixed return_bsym = None for idx, bsym in enumerate(cse_trace.bound_symbols): if bsym.sym.id == prims.PrimIDs.RETURN: From 1f01b01cb938ea8bec3009db97f8ee372239bba1 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 17 Dec 2025 01:44:24 -0800 Subject: [PATCH 15/15] Empty commit