API Reference
Tensor Saving
torchspy.saver.TensorSaver
Bases: BaseDebugger
Utility to save intermediate tensors during PyTorch forward passes.
It records the path of each registered module in the model hierarchy so that spy_save(), called from inside a module's forward() method, can save tensors under names that include that path.
Attributes:

| Name | Type | Description |
|---|---|---|
| output_dir | Path | Directory where tensors are saved. |
| enabled | bool | Whether debugging is active. |
| call_counts | dict[str, int] | Tracks call count per tensor name. |
| module_paths | dict[int, str] | Maps module id to its path string. |

Example
```python
from torchspy import TensorSaver, DebugContext, spy_save

# Set up the saver (model and AttentionLayer are your own objects)
saver = TensorSaver("./debug_tensors")
saver.register_modules(model, target_classes=(AttentionLayer,))

# Run the forward pass inside a debug context
with DebugContext(saver, prefix="step0"):
    output = model(inputs)

# Inside your module's forward(), call spy_save():
#     spy_save("q", q, self)  # saves as {prefix}.{module_path}.q.call{n}.pt
```
Source code in src/torchspy/saver.py
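The filename pattern shown in the example can be reproduced with a few lines of plain Python. This is an illustrative sketch only. build_tensor_filename is a hypothetical helper, not part of torchspy. Skipping empty segments, such as the default prefix '', is an assumption about how the library behaves.

```python
def build_tensor_filename(prefix: str, module_path: str, name: str, call_idx: int) -> str:
    """Compose a name following the documented
    {prefix}.{module_path}.{name}.call{n}.pt pattern (hypothetical helper).

    Assumption: empty segments (e.g. the default prefix "") are skipped
    instead of producing a leading dot.
    """
    parts = [prefix, module_path, name, f"call{call_idx}"]
    return ".".join(p for p in parts if p) + ".pt"


print(build_tensor_filename("step0", "encoder.layers.0.attn", "q", 0))
# step0.encoder.layers.0.attn.q.call0.pt
```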
__init__(output_dir, enabled=True)
Initialize the tensor saver.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output_dir | str \| Path | Directory to save tensor files. | required |
| enabled | bool | Whether saving is enabled. | True |

Source code in src/torchspy/saver.py
register_modules(model, target_classes=(nn.Module,), target_names=None, exclude_names=None)
Register module paths for target modules.
This populates the module_paths mapping so that spy_save() can look up the module's path in the model hierarchy.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The root model to inspect. | required |
| target_classes | tuple[Type[Module], ...] | Register modules that are instances of these classes. | (nn.Module,) |
| target_names | list[str] \| None | Register modules whose path contains any of these substrings. | None |
| exclude_names | list[str] \| None | Exclude modules whose path contains any of these substrings. | None |

Source code in src/torchspy/saver.py
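One way to read the three filters is sketched below. It operates on (path, module) pairs like those yielded by PyTorch's Module.named_modules(). The helper select_module_paths is hypothetical, and the way the criteria combine is an assumption. Under that assumption, a module must match target_classes. When target_names is given, its path must also contain one of those substrings. Finally, the module is dropped if its path contains any exclude_names substring. Plain classes, with object as the default, stand in for nn.Module so the sketch runs without PyTorch.

```python
from __future__ import annotations

from typing import Iterable, Optional


def select_module_paths(
    named_modules: Iterable[tuple[str, object]],
    target_classes: tuple[type, ...] = (object,),  # stand-in for (nn.Module,)
    target_names: Optional[list[str]] = None,
    exclude_names: Optional[list[str]] = None,
) -> dict[int, str]:
    """Return an id(module) -> path mapping (hypothetical helper).

    Assumed rules: class match AND (any target_names substring, if given)
    AND NOT (any exclude_names substring).
    """
    selected: dict[int, str] = {}
    for path, module in named_modules:
        if not isinstance(module, target_classes):
            continue
        if target_names is not None and not any(s in path for s in target_names):
            continue
        if exclude_names is not None and any(s in path for s in exclude_names):
            continue
        selected[id(module)] = path
    return selected


# Plain classes stand in for nn.Module subclasses.
class Attention: ...


class Linear: ...


modules = [("enc.attn", Attention()), ("enc.ffn", Linear()), ("dec.attn", Attention())]
paths = select_module_paths(modules, target_classes=(Attention,), exclude_names=["dec"])
print(sorted(paths.values()))  # ['enc.attn']
```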
save(name, tensor, call_idx=None, norm_only=False)
Save a tensor to disk.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | The tensor name (used in the filename). | required |
| tensor | Tensor | The tensor to save. | required |
| call_idx | int \| None | Call index. If None, it is auto-incremented per tensor name. | None |
| norm_only | bool | If True, save only the L2 norm of the tensor (flattened per batch) instead of the full tensor. | False |

Source code in src/torchspy/saver.py
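The call_idx behaviour can be modelled with a per-name counter. This stdlib-only sketch makes two assumptions. First, indices start at 0, as the call0 suffix in the DebugContext example suggests. Second, an explicit call_idx is used as-is without advancing the counter. CallCounter is a hypothetical class, not torchspy's implementation. The sketch also shows how reset_counts() restarts the numbering.

```python
from collections import defaultdict
from typing import Optional


class CallCounter:
    """Hypothetical model of the call_counts bookkeeping in TensorSaver."""

    def __init__(self) -> None:
        # Missing names start at 0, matching the ".call0" suffix on first saves.
        self.call_counts = defaultdict(int)

    def next_index(self, name: str, call_idx: Optional[int] = None) -> int:
        if call_idx is not None:
            # Assumption: an explicit index is used as-is and the counter is left alone.
            return call_idx
        idx = self.call_counts[name]
        self.call_counts[name] += 1
        return idx

    def reset_counts(self) -> None:
        self.call_counts.clear()


counter = CallCounter()
print([counter.next_index("q") for _ in range(3)])  # [0, 1, 2]
print(counter.next_index("k"))                      # 0 (counters are per name)
print(counter.next_index("q", call_idx=7))          # 7
counter.reset_counts()
print(counter.next_index("q"))                      # 0
```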
reset_counts()
Reset all call counters.
Call this between batches if you want call indices to restart from 0 for each batch.
Source code in src/torchspy/saver.py
torchspy.context.DebugContext
Context manager for scoped tensor debugging.
Use it to add a prefix to the names of saved tensors and to enable spy_save() calls inside module forward methods.
Attributes:

| Name | Type | Description |
|---|---|---|
| saver | TensorSaver | The saver instance to use. |
| prefix | str | Prefix added to all tensor names in this context. |
| module_path_override | str \| None | Override module path for spy_save(). |

Example
```python
from torchspy import TensorSaver, DebugContext

saver = TensorSaver("./debug_tensors")
with DebugContext(saver, prefix="batch0_step0"):
    output = model(inputs)  # model and inputs are your own objects
    # Inside forward: spy_save("q", q, self)
    # Saves as: batch0_step0.{module_path}.q.call0.pt
```
Source code in src/torchspy/context.py
__init__(saver, prefix='', module_path_override=None)
Initialize the debug context.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| saver | TensorSaver | The saver instance. | required |
| prefix | str | Prefix for tensor names. | '' |
| module_path_override | str \| None | Override the module path. Useful when calling spy_save() from helper functions. | None |

Source code in src/torchspy/context.py
__enter__()
Enter the debug context.
Source code in src/torchspy/context.py
__exit__(exc_type, exc_val, exc_tb)
Exit the debug context.
Source code in src/torchspy/context.py
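The enter/exit pair can be sketched with contextvars, which keeps the active value isolated per thread and per asyncio task. This is an assumption about the mechanism, not torchspy's documented implementation. SketchContext and get_sketch_context are hypothetical stand-ins for DebugContext and get_debug_context(). Restoring the outer context when a nested one exits is a design choice of this sketch.

```python
import contextvars
from typing import Optional

# Assumption: the active context lives in a ContextVar. torchspy may use a
# different mechanism; this only models the documented enter/exit behaviour.
_ACTIVE = contextvars.ContextVar("active_debug_context", default=None)


class SketchContext:
    """Hypothetical stand-in for DebugContext (prefix only)."""

    def __init__(self, prefix: str = "") -> None:
        self.prefix = prefix
        self._token = None

    def __enter__(self) -> "SketchContext":
        # set() returns a token that remembers the previous value.
        self._token = _ACTIVE.set(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Restore whatever was active before (None or an outer context).
        _ACTIVE.reset(self._token)
        self._token = None
        # Returning None lets any exception propagate.


def get_sketch_context() -> Optional[SketchContext]:
    return _ACTIVE.get()


print(get_sketch_context())  # None
with SketchContext("step0") as outer:
    with SketchContext("step0_inner"):
        print(get_sketch_context().prefix)  # step0_inner
    print(get_sketch_context() is outer)    # True
print(get_sketch_context())  # None
```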
torchspy.saver.spy_save(name, tensor, module=None)
Save a tensor if a debug context is active.
This is a convenience function to call from within module forward methods. If no debug context is active, it does nothing.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | The tensor variable name (e.g., "q", "k", "attn_mask"). | required |
| tensor | Tensor | The tensor to save. | required |
| module | Module \| None | The module instance (self), used to determine the module path. Pass it from within a module's forward(). | None |

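Example
```python
import torch.nn as nn
from torchspy import spy_save

class MyModule(nn.Module):
    def forward(self, x):
        q = self.proj_q(x)  # proj_q is created in __init__ (omitted)
        spy_save("q", q, self)  # no-op unless a DebugContext is active
        return q
```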
Source code in src/torchspy/saver.py
torchspy.context.get_debug_context()
Get the current debug context.
Returns:

| Type | Description |
|---|---|
| DebugContext \| None | The active context, or None if no context is active. |

Source code in src/torchspy/context.py
Call Tracing
torchspy.tracer.CallTracer
Bases: BaseDebugger
Traces the execution order of PyTorch modules using forward hooks.
This class registers forward hooks on target modules to automatically record their call order without requiring any modifications to the module code.
Attributes:

| Name | Type | Description |
|---|---|---|
| output_dir | Path | Directory where trace files are saved. |
| enabled | bool | Whether tracing is active. |
| call_trace | list[str] | List of module paths in call order. |
| module_paths | dict[int, str] | Maps module id to its path string. |
| hooks | list | List of registered hook handles for cleanup. |

Example
```python
import torch.nn as nn
from torchspy import CallTracer

# model and inputs are your own objects
tracer = CallTracer("./debug_traces")
tracer.register_hooks(model, target_classes=(nn.Linear, nn.MultiheadAttention))

# Run the forward pass; the hooks record the call order automatically
output = model(inputs)

# Save the trace
tracer.save_trace("forward_pass.txt")

# Or inspect the trace as a list
print(tracer.call_trace)
```
Source code in src/torchspy/tracer.py
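Recording call order with hooks boils down to one closure per module that appends the module's path to a shared list. The sketch below builds such a closure with the (module, args, output) signature that PyTorch passes to forward hooks. With PyTorch installed, it could be attached via module.register_forward_hook(...), whose returned handle has a remove() method. make_trace_hook is hypothetical, not torchspy's code. Forward hooks fire after a module's forward() returns, so if CallTracer uses this kind of hook, a container module would appear after its children. To stay runnable without PyTorch, the sketch calls the hooks directly in that order.

```python
from __future__ import annotations

from typing import Any, Callable


def make_trace_hook(path: str, call_trace: list[str]) -> Callable[[Any, Any, Any], None]:
    """Return a forward hook that appends `path` to `call_trace` (hypothetical helper)."""

    def hook(module: Any, args: Any, output: Any) -> None:
        # Returning None leaves the module's output unchanged.
        call_trace.append(path)

    return hook


trace: list[str] = []
hooks = {p: make_trace_hook(p, trace) for p in ("enc.attn", "enc.ffn", "enc")}

# Simulate the order in which forward hooks fire: children finish before their parent.
for p in ("enc.attn", "enc.ffn", "enc"):
    hooks[p](None, (), None)

print(trace)  # ['enc.attn', 'enc.ffn', 'enc']
```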
__init__(output_dir, enabled=True)
Initialize the call tracer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output_dir | str \| Path | Directory to save trace files. | required |
| enabled | bool | Whether tracing is enabled. | True |

Source code in src/torchspy/tracer.py
register_hooks(model, target_classes=(nn.Module,), target_names=None, exclude_names=None)
Register forward hooks on target modules.
This method walks through the model hierarchy and registers forward hooks on modules that match the specified criteria.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The root model to inspect. | required |
| target_classes | tuple[Type[Module], ...] | Register hooks on modules that are instances of these classes. | (nn.Module,) |
| target_names | list[str] \| None | Register hooks on modules whose path contains any of these substrings. | None |
| exclude_names | list[str] \| None | Exclude modules whose path contains any of these substrings. | None |

Source code in src/torchspy/tracer.py
save_trace(filename='call_trace.txt')
Save the call trace to a text file in output_dir.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| filename | str | Name of the output file. | 'call_trace.txt' |

Returns:

| Type | Description |
|---|---|
| Path | Path to the saved trace file. |

Source code in src/torchspy/tracer.py
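A minimal sketch of writing a trace and returning its Path. It assumes one module path per line, since the actual file format is not documented here. save_trace_sketch is a hypothetical helper, not torchspy's method.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def save_trace_sketch(call_trace: list[str], output_dir: Path, filename: str = "call_trace.txt") -> Path:
    """Write one module path per line and return the file's Path (assumed format)."""
    path = Path(output_dir) / filename
    path.write_text("\n".join(call_trace) + "\n", encoding="utf-8")
    return path


with tempfile.TemporaryDirectory() as tmp:
    out = save_trace_sketch(["enc.attn", "enc.ffn"], Path(tmp), "forward_pass.txt")
    print(out.name)                                   # forward_pass.txt
    print(out.read_text(encoding="utf-8").splitlines())  # ['enc.attn', 'enc.ffn']
```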
get_trace()
Get the current call trace.
Returns:

| Type | Description |
|---|---|
| list[str] | List of module paths in call order. |

Source code in src/torchspy/tracer.py
reset_trace()
Clear the recorded call trace.
Source code in src/torchspy/tracer.py
remove_hooks()
Remove all registered forward hooks.
Source code in src/torchspy/tracer.py
Base Class
torchspy._base.BaseDebugger
Base class for tensor debugging utilities.
Provides common functionality for directory management and module registration that is shared between TensorSaver and CallTracer.
Attributes:

| Name | Type | Description |
|---|---|---|
| output_dir | Path | Directory where output files are saved. |
| enabled | bool | Whether the debugger is active. |
| module_paths | dict[int, str] | Maps module id to its path string. |

Source code in src/torchspy/_base.py
__init__(output_dir, enabled=True)
Initialize the base debugger.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output_dir | str \| Path | Directory to save output files. | required |
| enabled | bool | Whether debugging is enabled. | True |

Source code in src/torchspy/_base.py
get_module_path(module)
Get the registered path for a module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module to look up. | required |

Returns:

| Type | Description |
|---|---|
| str | The module's path, or "unknown" if it is not registered. |

Source code in src/torchspy/_base.py
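The module_paths mapping and its "unknown" fallback can be sketched with a dictionary keyed by id(). ModulePathRegistry is a hypothetical class, not BaseDebugger itself.

```python
class ModulePathRegistry:
    """Hypothetical model of BaseDebugger's module_paths bookkeeping."""

    def __init__(self) -> None:
        self.module_paths: dict[int, str] = {}

    def register(self, path: str, module: object) -> None:
        # Key by object identity, matching the documented dict[int, str] type.
        self.module_paths[id(module)] = path

    def get_module_path(self, module: object) -> str:
        # Unregistered modules fall back to "unknown", as documented.
        return self.module_paths.get(id(module), "unknown")


registry = ModulePathRegistry()
attn, ffn = object(), object()
registry.register("enc.attn", attn)
print(registry.get_module_path(attn))  # enc.attn
print(registry.get_module_path(ffn))   # unknown
```

CPython may reuse an id() after an object is garbage-collected. An id-keyed map is therefore only reliable while the registered modules stay alive, which is normally the case for submodules of a model still in use.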