Skip to content

API Reference

Tensor Saving

torchspy.saver.TensorSaver

Bases: BaseDebugger

Utility to save intermediate tensors during PyTorch forward passes.

This class registers module paths and provides infrastructure for saving tensors from within module forward methods via the spy_save() function.

Attributes:

Name Type Description
output_dir Path

Directory where tensors are saved.

enabled bool

Whether debugging is active.

call_counts dict[str, int]

Tracks call count per tensor name.

module_paths dict[int, str]

Maps module id to its path string.

Example

from torchspy import TensorSaver, DebugContext, spy_save

Setup saver

saver = TensorSaver("./debug_tensors")
saver.register_modules(model, target_classes=(AttentionLayer,))

Run with debug context

with DebugContext(saver, prefix="step0"):
    output = model(inputs)

Inside your module's forward(), call spy_save():

spy_save("q", q, self) # saves as {prefix}.{module_path}.q.call{n}.pt

Source code in src/torchspy/saver.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class TensorSaver(BaseDebugger):
    """Utility to save intermediate tensors during PyTorch forward passes.

    This class registers module paths and provides infrastructure for saving
    tensors from within module forward methods via the spy_save() function.

    Attributes:
        output_dir (Path): Directory where tensors are saved.
        enabled (bool): Whether debugging is active.
        call_counts (dict[str, int]): Tracks call count per tensor name.
        module_paths (dict[int, str]): Maps module id to its path string.

    Example:
        >>> from torchspy import TensorSaver, DebugContext, spy_save
        >>>
        >>> # Setup saver
        >>> saver = TensorSaver("./debug_tensors")
        >>> saver.register_modules(model, target_classes=(AttentionLayer,))
        >>>
        >>> # Run with debug context
        >>> with DebugContext(saver, prefix="step0"):
        ...     output = model(inputs)
        >>>
        >>> # Inside your module's forward(), call spy_save():
        >>> # spy_save("q", q, self)  # saves as {prefix}.{module_path}.q.call{n}.pt

    """

    def __init__(self, output_dir: str | Path, enabled: bool = True) -> None:
        """Initialize the tensor saver.

        Args:
            output_dir (str | Path): Directory to save tensor files.
            enabled (bool): Whether saving is enabled. Defaults to True.

        """
        super().__init__(output_dir, enabled)
        # defaultdict(int): the first save() of a new name gets call index 0.
        self.call_counts: dict[str, int] = defaultdict(int)

    def register_modules(
        self,
        model: nn.Module,
        target_classes: tuple[Type[nn.Module], ...] = (nn.Module,),
        target_names: list[str] | None = None,
        exclude_names: list[str] | None = None,
    ) -> None:
        """Register module paths for target modules.

        This populates the module_paths mapping so that spy_save() can
        look up the module's path in the model hierarchy.

        Args:
            model (nn.Module): The root model to inspect.
            target_classes (tuple[Type[nn.Module], ...]): Register modules
                that are instances of these classes.
            target_names (list[str] | None): Register modules whose path
                contains any of these substrings.
            exclude_names (list[str] | None): Exclude modules whose path
                contains any of these substrings.

        """
        for name, module in model.named_modules():
            # named_modules() yields "" for the root module itself.
            module_path = name or "root"
            if self._should_register_module(
                module, module_path, target_classes, target_names, exclude_names
            ):
                self.module_paths[id(module)] = module_path
                logger.info("Registered module path: %s", module_path)

    def save(
        self,
        name: str,
        tensor: Tensor,
        call_idx: int | None = None,
        norm_only: bool = False,
    ) -> None:
        """Save a tensor to disk.

        Args:
            name (str): The tensor name (will be used in filename).
            tensor (Tensor): The tensor to save.
            call_idx (int | None): Call index. If None, auto-increments
                based on name.
            norm_only (bool): If True, save only the L2 norm of the tensor
                (flattened per batch). Defaults to False.

        """
        if not self.enabled:
            return

        if call_idx is None:
            call_idx = self.call_counts[name]
            self.call_counts[name] += 1

        filename = f"{name}.call{call_idx}.pt"
        path = self.output_dir / filename
        if norm_only:
            # reshape (not view) so non-contiguous tensors (e.g. transposed
            # or sliced activations) do not raise a RuntimeError.
            tensor = tensor.reshape(tensor.size(0), -1)
            tensor = LA.norm(tensor, dim=-1)
        torch.save(tensor.detach().cpu(), path)
        logger.debug("Saved tensor: %s", path)

    def reset_counts(self) -> None:
        """Reset all call counters.

        Call this between batches if you want per-batch indexing.

        """
        self.call_counts.clear()

__init__(output_dir, enabled=True)

Initialize the tensor saver.

Parameters:

Name Type Description Default
output_dir str | Path

Directory to save tensor files.

required
enabled bool

Whether saving is enabled. Defaults to True.

True
Source code in src/torchspy/saver.py
50
51
52
53
54
55
56
57
58
59
def __init__(self, output_dir: str | Path, enabled: bool = True) -> None:
    """Initialize the tensor saver.

    Args:
        output_dir (str | Path): Directory to save tensor files.
        enabled (bool): Whether saving is enabled. Defaults to True.

    """
    super().__init__(output_dir, enabled)
    # defaultdict(int): the first save() of a new name gets call index 0.
    self.call_counts: dict[str, int] = defaultdict(int)

register_modules(model, target_classes=(nn.Module,), target_names=None, exclude_names=None)

Register module paths for target modules.

This populates the module_paths mapping so that spy_save() can look up the module's path in the model hierarchy.

Parameters:

Name Type Description Default
model Module

The root model to inspect.

required
target_classes tuple[Type[Module], ...]

Register modules that are instances of these classes.

(Module,)
target_names list[str] | None

Register modules whose path contains any of these substrings.

None
exclude_names list[str] | None

Exclude modules whose path contains any of these substrings.

None
Source code in src/torchspy/saver.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def register_modules(
    self,
    model: nn.Module,
    target_classes: tuple[Type[nn.Module], ...] = (nn.Module,),
    target_names: list[str] | None = None,
    exclude_names: list[str] | None = None,
) -> None:
    """Record the hierarchy path of every matching module.

    Fills the module_paths mapping so spy_save() can later resolve a
    module instance back to its dotted path in the model.

    Args:
        model (nn.Module): The root model to inspect.
        target_classes (tuple[Type[nn.Module], ...]): Register modules
            that are instances of these classes.
        target_names (list[str] | None): Register modules whose path
            contains any of these substrings.
        exclude_names (list[str] | None): Exclude modules whose path
            contains any of these substrings.

    """
    for name, module in model.named_modules():
        # The root module is yielded with an empty name.
        path = name if name else "root"
        if not self._should_register_module(
            module, path, target_classes, target_names, exclude_names
        ):
            continue
        self.module_paths[id(module)] = path
        logger.info("Registered module path: %s", path)

save(name, tensor, call_idx=None, norm_only=False)

Save a tensor to disk.

Parameters:

Name Type Description Default
name str

The tensor name (will be used in filename).

required
tensor Tensor

The tensor to save.

required
call_idx int | None

Call index. If None, auto-increments based on name.

None
norm_only bool

If True, save only the L2 norm of the tensor (flattened per batch). Defaults to False.

False
Source code in src/torchspy/saver.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def save(
    self,
    name: str,
    tensor: Tensor,
    call_idx: int | None = None,
    norm_only: bool = False,
) -> None:
    """Save a tensor to disk.

    Args:
        name (str): The tensor name (will be used in filename).
        tensor (Tensor): The tensor to save.
        call_idx (int | None): Call index. If None, auto-increments
            based on name.
        norm_only (bool): If True, save only the L2 norm of the tensor
            (flattened per batch). Defaults to False.

    """
    if not self.enabled:
        return

    if call_idx is None:
        call_idx = self.call_counts[name]
        self.call_counts[name] += 1

    filename = f"{name}.call{call_idx}.pt"
    path = self.output_dir / filename
    if norm_only:
        # reshape (not view) so non-contiguous tensors (e.g. transposed or
        # sliced activations) do not raise a RuntimeError.
        tensor = tensor.reshape(tensor.size(0), -1)
        tensor = LA.norm(tensor, dim=-1)
    torch.save(tensor.detach().cpu(), path)
    logger.debug("Saved tensor: %s", path)

reset_counts()

Reset all call counters.

Call this between batches if you want per-batch indexing.

Source code in src/torchspy/saver.py
124
125
126
127
128
129
130
def reset_counts(self) -> None:
    """Clear every per-name call counter.

    Useful between batches when per-batch call indexing is desired.

    """
    self.call_counts.clear()

torchspy.context.DebugContext

Context manager for scoped tensor debugging.

Use this to add a prefix to saved tensors and enable spy_save() calls within module forward methods.

Attributes:

Name Type Description
saver TensorSaver

The saver instance to use.

prefix str

Prefix added to all tensor names in this context.

module_path_override str | None

Override module path for spy_save().

Example

from torchspy import TensorSaver, DebugContext

saver = TensorSaver("./debug_tensors")
with DebugContext(saver, prefix="batch0_step0"):
    output = model(inputs)
    # Inside forward: spy_save("q", q, self)
    # Saves as: batch0_step0.{module_path}.q.call0.pt

Source code in src/torchspy/context.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
class DebugContext:
    """Context manager for scoped tensor debugging.

    While the context is active, spy_save() calls inside module forward
    methods become live, and every saved tensor name gets this context's
    prefix.

    Attributes:
        saver (TensorSaver): The saver instance to use.
        prefix (str): Prefix added to all tensor names in this context.
        module_path_override (str | None): Override module path for spy_save().

    Example:
        >>> from torchspy import TensorSaver, DebugContext
        >>>
        >>> saver = TensorSaver("./debug_tensors")
        >>> with DebugContext(saver, prefix="batch0_step0"):
        ...     output = model(inputs)
        ...     # Inside forward: spy_save("q", q, self)
        ...     # Saves as: batch0_step0.{module_path}.q.call0.pt

    """

    def __init__(
        self,
        saver: "TensorSaver",
        prefix: str = "",
        module_path_override: str | None = None,
    ) -> None:
        """Initialize the debug context.

        Args:
            saver (TensorSaver): The saver instance.
            prefix (str): Prefix for tensor names. Defaults to "".
            module_path_override (str | None): Override the module path.
                Useful when calling spy_save() from helper functions.

        """
        # No ContextVar token until __enter__ runs.
        self._token: contextvars.Token | None = None
        self.saver = saver
        self.prefix = prefix
        self.module_path_override = module_path_override

    @property
    def debugger(self) -> "TensorSaver":
        """Backward compatibility alias for saver."""
        return self.saver

    def __enter__(self) -> "DebugContext":
        """Enter the debug context."""
        self._token = _debug_context.set(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit the debug context."""
        if self._token is None:
            return
        _debug_context.reset(self._token)

    def _save(self, name: str, tensor: Tensor, module: nn.Module | None = None) -> None:
        """Save a tensor with context-aware naming.

        Args:
            name (str): The tensor variable name (e.g., "q", "attn_mask").
            tensor (Tensor): The tensor to save.
            module (nn.Module | None): The module saving this tensor.
                Used to look up the module path.

        """
        if self.module_path_override is not None:
            module_path = self.module_path_override
        else:
            module_path = (
                "manual" if module is None else self.saver.get_module_path(module)
            )

        if module_path == "unknown":
            logger.warning("Module path unknown for tensor '%s'. Skipping save.", name)
            return

        # Assemble "{prefix}.{module_path}.{name}", dropping an empty prefix.
        parts = ([self.prefix] if self.prefix else []) + [module_path, name]
        self.saver.save(".".join(parts), tensor)
__init__(saver, prefix='', module_path_override=None)

Initialize the debug context.

Parameters:

Name Type Description Default
saver TensorSaver

The saver instance.

required
prefix str

Prefix for tensor names. Defaults to "".

''
module_path_override str | None

Override the module path. Useful when calling spy_save() from helper functions.

None
Source code in src/torchspy/context.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def __init__(
    self,
    saver: "TensorSaver",
    prefix: str = "",
    module_path_override: str | None = None,
) -> None:
    """Initialize the debug context.

    Args:
        saver (TensorSaver): The saver instance.
        prefix (str): Prefix for tensor names. Defaults to "".
        module_path_override (str | None): Override the module path.
            Useful when calling spy_save() from helper functions.

    """
    # No ContextVar token until __enter__ runs.
    self._token: contextvars.Token | None = None
    self.saver = saver
    self.prefix = prefix
    self.module_path_override = module_path_override

__enter__()

Enter the debug context.

Source code in src/torchspy/context.py
73
74
75
76
def __enter__(self) -> "DebugContext":
    """Enter the debug context."""
    # Keep the reset token so __exit__ can restore whatever context
    # (possibly None) was active before this one.
    self._token = _debug_context.set(self)
    return self

__exit__(exc_type, exc_val, exc_tb)

Exit the debug context.

Source code in src/torchspy/context.py
78
79
80
81
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Exit the debug context."""
    if self._token is not None:
        _debug_context.reset(self._token)

torchspy.saver.spy_save(name, tensor, module=None)

Save a tensor if a debug context is active.

This is a convenience function to call from within module forward methods. If no debug context is active, this function does nothing (no-op).

Parameters:

Name Type Description Default
name str

The tensor variable name (e.g., "q", "k", "attn_mask").

required
tensor Tensor

The tensor to save.

required
module Module | None

The module instance (self). Used to determine the module path. Pass from within a module's forward().

None
Example

class MyModule(nn.Module):
    def forward(self, x):
        q = self.proj_q(x)
        spy_save("q", q, self)
        return q

Source code in src/torchspy/saver.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def spy_save(name: str, tensor: Tensor, module: nn.Module | None = None) -> None:
    """Save a tensor when a debug context is active; otherwise do nothing.

    Convenience entry point meant to be called from inside a module's
    forward() method. Without an active DebugContext this is a no-op.

    Args:
        name (str): The tensor variable name (e.g., "q", "k", "attn_mask").
        tensor (Tensor): The tensor to save.
        module (nn.Module | None): The module instance (self). Used to
            determine the module path. Pass from within a module's forward().

    Example:
        >>> class MyModule(nn.Module):
        ...     def forward(self, x):
        ...         q = self.proj_q(x)
        ...         spy_save("q", q, self)
        ...         return q

    """
    ctx = get_debug_context()
    if ctx is None:
        return
    ctx._save(name, tensor, module)

torchspy.context.get_debug_context()

Get the current debug context.

Returns:

Type Description
DebugContext | None

DebugContext | None: The active context, or None if no context is active.

Source code in src/torchspy/context.py
112
113
114
115
116
117
118
119
def get_debug_context() -> DebugContext | None:
    """Get the current debug context.

    Returns:
        DebugContext | None: The active context, or None if no context is active.

    """
    # Returns the DebugContext most recently set by DebugContext.__enter__.
    return _debug_context.get()

Call Tracing

torchspy.tracer.CallTracer

Bases: BaseDebugger

Traces the execution order of PyTorch modules using forward hooks.

This class registers forward hooks on target modules to automatically record their call order without requiring any modifications to the module code.

Attributes:

Name Type Description
output_dir Path

Directory where trace files are saved.

enabled bool

Whether tracing is active.

call_trace list[str]

List of module paths in call order.

module_paths dict[int, str]

Maps module id to its path string.

hooks list

List of registered hook handles for cleanup.

Example

from torchspy import CallTracer

tracer = CallTracer("./debug_traces")
tracer.register_hooks(model, target_classes=(nn.Linear, nn.MultiheadAttention))

Run forward pass - hooks automatically record call order

output = model(inputs)

Save the trace

tracer.save_trace("forward_pass.txt")

Or get the trace as a list

print(tracer.call_trace)

Source code in src/torchspy/tracer.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class CallTracer(BaseDebugger):
    """Traces the execution order of PyTorch modules using forward hooks.

    This class registers forward hooks on target modules to automatically
    record their call order without requiring any modifications to the
    module code.

    Attributes:
        output_dir (Path): Directory where trace files are saved.
        enabled (bool): Whether tracing is active.
        call_trace (list[str]): List of module paths in call order.
        module_paths (dict[int, str]): Maps module id to its path string.
        hooks (list): List of registered hook handles for cleanup.

    Example:
        >>> from torchspy import CallTracer
        >>>
        >>> tracer = CallTracer("./debug_traces")
        >>> tracer.register_hooks(model, target_classes=(nn.Linear, nn.MultiheadAttention))
        >>>
        >>> # Run forward pass - hooks automatically record call order
        >>> output = model(inputs)
        >>>
        >>> # Save the trace
        >>> tracer.save_trace("forward_pass.txt")
        >>> # Or get the trace as a list
        >>> print(tracer.call_trace)

    """

    def __init__(self, output_dir: str | Path, enabled: bool = True) -> None:
        """Initialize the call tracer.

        Args:
            output_dir (str | Path): Directory to save trace files.
            enabled (bool): Whether tracing is enabled. Defaults to True.

        """
        super().__init__(output_dir, enabled)
        self.call_trace: list[str] = []
        # Hook handles kept so remove_hooks() can detach them later.
        self.hooks: list[Any] = []

    def register_hooks(
        self,
        model: nn.Module,
        target_classes: tuple[Type[nn.Module], ...] = (nn.Module,),
        target_names: list[str] | None = None,
        exclude_names: list[str] | None = None,
    ) -> None:
        """Register forward hooks on target modules.

        This method walks through the model hierarchy and registers forward
        hooks on modules that match the specified criteria.

        Args:
            model (nn.Module): The root model to inspect.
            target_classes (tuple[Type[nn.Module], ...]): Register hooks on
                modules that are instances of these classes. Defaults to (nn.Module,).
            target_names (list[str] | None): Register hooks on modules whose
                path contains any of these substrings.
            exclude_names (list[str] | None): Exclude modules whose path
                contains any of these substrings.

        """
        for name, module in model.named_modules():
            # named_modules() yields "" for the root module itself.
            module_path = name or "root"
            if self._should_register_module(
                module, module_path, target_classes, target_names, exclude_names
            ):
                self.module_paths[id(module)] = module_path
                hook = module.register_forward_hook(self._create_hook(module_path))
                self.hooks.append(hook)
                logger.info("Registered hook on: %s", module_path)

    def _create_hook(self, module_path: str) -> Callable:
        """Create a forward hook that records the module path.

        Args:
            module_path (str): The path of the module in the model hierarchy.

        Returns:
            Callable: A hook function that records the module call.

        """

        def hook(module: nn.Module, input: Any, output: Any) -> None:  # noqa: A002
            if self.enabled:
                self.call_trace.append(module_path)

        return hook

    def save_trace(self, filename: str = "call_trace.txt") -> Path:
        """Save the call trace to a text file.

        Args:
            filename (str): Name of the output file. Defaults to "call_trace.txt".

        Returns:
            Path: Path to the saved trace file.

        """
        path = self.output_dir / filename
        # Explicit encoding: the platform default may not be UTF-8.
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(self.call_trace))
        logger.info("Saved call trace to: %s (%d calls)", path, len(self.call_trace))
        return path

    def get_trace(self) -> list[str]:
        """Get the current call trace.

        Returns:
            list[str]: List of module paths in call order.

        """
        return self.call_trace.copy()

    def reset_trace(self) -> None:
        """Clear the recorded call trace."""
        self.call_trace.clear()

    def remove_hooks(self) -> None:
        """Remove all registered forward hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        logger.info("Removed all hooks")

    def __enter__(self) -> "CallTracer":
        """Enter context - enable tracing."""
        self.enabled = True
        self.reset_trace()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit context - optionally save trace."""
        pass

    def __del__(self) -> None:
        """Clean up hooks on deletion."""
        try:
            self.remove_hooks()
        except Exception:
            # __del__ may run before __init__ finished (self.hooks missing)
            # or during interpreter shutdown when logging is partially torn
            # down; a finalizer must never propagate exceptions.
            pass

__init__(output_dir, enabled=True)

Initialize the call tracer.

Parameters:

Name Type Description Default
output_dir str | Path

Directory to save trace files.

required
enabled bool

Whether tracing is enabled. Defaults to True.

True
Source code in src/torchspy/tracer.py
48
49
50
51
52
53
54
55
56
57
58
def __init__(self, output_dir: str | Path, enabled: bool = True) -> None:
    """Initialize the call tracer.

    Args:
        output_dir (str | Path): Directory to save trace files.
        enabled (bool): Whether tracing is enabled. Defaults to True.

    """
    super().__init__(output_dir, enabled)
    # Recorded module paths, appended in forward-call order by the hooks.
    self.call_trace: list[str] = []
    # Hook handles kept so remove_hooks() can detach them later.
    self.hooks: list[Any] = []

register_hooks(model, target_classes=(nn.Module,), target_names=None, exclude_names=None)

Register forward hooks on target modules.

This method walks through the model hierarchy and registers forward hooks on modules that match the specified criteria.

Parameters:

Name Type Description Default
model Module

The root model to inspect.

required
target_classes tuple[Type[Module], ...]

Register hooks on modules that are instances of these classes. Defaults to (nn.Module,).

(Module,)
target_names list[str] | None

Register hooks on modules whose path contains any of these substrings.

None
exclude_names list[str] | None

Exclude modules whose path contains any of these substrings.

None
Source code in src/torchspy/tracer.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def register_hooks(
    self,
    model: nn.Module,
    target_classes: tuple[Type[nn.Module], ...] = (nn.Module,),
    target_names: list[str] | None = None,
    exclude_names: list[str] | None = None,
) -> None:
    """Attach forward hooks to every module matching the filters.

    Walks the model hierarchy and hooks each module that passes the
    class/name filters so its calls are recorded automatically.

    Args:
        model (nn.Module): The root model to inspect.
        target_classes (tuple[Type[nn.Module], ...]): Register hooks on
            modules that are instances of these classes. Defaults to (nn.Module,).
        target_names (list[str] | None): Register hooks on modules whose
            path contains any of these substrings.
        exclude_names (list[str] | None): Exclude modules whose path
            contains any of these substrings.

    """
    for name, module in model.named_modules():
        # The root module is yielded with an empty name.
        path = name if name else "root"
        if not self._should_register_module(
            module, path, target_classes, target_names, exclude_names
        ):
            continue
        self.module_paths[id(module)] = path
        handle = module.register_forward_hook(self._create_hook(path))
        self.hooks.append(handle)
        logger.info("Registered hook on: %s", path)

save_trace(filename='call_trace.txt')

Save the call trace to a text file.

Parameters:

Name Type Description Default
filename str

Name of the output file. Defaults to "call_trace.txt".

'call_trace.txt'

Returns:

Name Type Description
Path Path

Path to the saved trace file.

Source code in src/torchspy/tracer.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def save_trace(self, filename: str = "call_trace.txt") -> Path:
    """Save the call trace to a text file, one module path per line.

    Args:
        filename (str): Name of the output file. Defaults to "call_trace.txt".

    Returns:
        Path: Path to the saved trace file.

    """
    path = self.output_dir / filename
    # Explicit encoding: the platform default may not be UTF-8.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(self.call_trace))
    logger.info("Saved call trace to: %s (%d calls)", path, len(self.call_trace))
    return path

get_trace()

Get the current call trace.

Returns:

Type Description
list[str]

list[str]: List of module paths in call order.

Source code in src/torchspy/tracer.py
125
126
127
128
129
130
131
132
def get_trace(self) -> list[str]:
    """Return a snapshot of the recorded call order.

    Returns:
        list[str]: Module paths in the order they were called. The list
            is a copy; mutating it does not affect the tracer.

    """
    return list(self.call_trace)

reset_trace()

Clear the recorded call trace.

Source code in src/torchspy/tracer.py
134
135
136
def reset_trace(self) -> None:
    """Drop all recorded calls; only call_trace is touched."""
    self.call_trace.clear()

remove_hooks()

Remove all registered forward hooks.

Source code in src/torchspy/tracer.py
138
139
140
141
142
143
def remove_hooks(self) -> None:
    """Detach every registered forward hook and forget its handle."""
    for handle in self.hooks:
        handle.remove()
    self.hooks.clear()
    logger.info("Removed all hooks")

Base Class

torchspy._base.BaseDebugger

Base class for tensor debugging utilities.

Provides common functionality for directory management and module registration that is shared between TensorSaver and CallTracer.

Attributes:

Name Type Description
output_dir Path

Directory where output files are saved.

enabled bool

Whether the debugger is active.

module_paths dict[int, str]

Maps module id to its path string.

Source code in src/torchspy/_base.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class BaseDebugger:
    """Base class for tensor debugging utilities.

    Provides common functionality for directory management and module
    registration that is shared between TensorSaver and CallTracer.

    Attributes:
        output_dir (Path): Directory where output files are saved.
        enabled (bool): Whether the debugger is active.
        module_paths (dict[int, str]): Maps module id to its path string.

    """

    def __init__(self, output_dir: str | Path, enabled: bool = True) -> None:
        """Initialize the base debugger.

        Args:
            output_dir (str | Path): Directory to save output files.
            enabled (bool): Whether debugging is enabled. Defaults to True.

        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.enabled = enabled
        self.module_paths: dict[int, str] = {}

    def _should_register_module(
        self,
        module: nn.Module,
        module_path: str,
        target_classes: tuple[Type[nn.Module], ...],
        target_names: list[str] | None,
        exclude_names: list[str] | None,
    ) -> bool:
        """Determine if a module should be registered.

        Args:
            module (nn.Module): The module to check.
            module_path (str): The module's path in the model hierarchy.
            target_classes (tuple[Type[nn.Module], ...]): Register if instance of these.
            target_names (list[str] | None): Register if path contains any of these.
            exclude_names (list[str] | None): Exclude if path contains any of these.

        Returns:
            bool: True if the module should be registered.

        """

        if not isinstance(module, target_classes):
            return False

        if target_names is not None and not any(t in module_path for t in target_names):
            return False

        if exclude_names is not None and any(t in module_path for t in exclude_names):
            return False

        return True

    def get_module_path(self, module: nn.Module) -> str:
        """Get the registered path for a module.

        Args:
            module (nn.Module): The module to look up.

        Returns:
            str: The module's path, or "unknown" if not registered.

        """
        return self.module_paths.get(id(module), "unknown")

__init__(output_dir, enabled=True)

Initialize the base debugger.

Parameters:

Name Type Description Default
output_dir str | Path

Directory to save output files.

required
enabled bool

Whether debugging is enabled. Defaults to True.

True
Source code in src/torchspy/_base.py
29
30
31
32
33
34
35
36
37
38
39
40
def __init__(self, output_dir: str | Path, enabled: bool = True) -> None:
    """Initialize the base debugger.

    Args:
        output_dir (str | Path): Directory to save output files.
        enabled (bool): Whether debugging is enabled. Defaults to True.

    """
    self.enabled = enabled
    self.module_paths: dict[int, str] = {}
    # Create the directory eagerly so later file writes find it present.
    self.output_dir = Path(output_dir)
    self.output_dir.mkdir(parents=True, exist_ok=True)

get_module_path(module)

Get the registered path for a module.

Parameters:

Name Type Description Default
module Module

The module to look up.

required

Returns:

Name Type Description
str str

The module's path, or "unknown" if not registered.

Source code in src/torchspy/_base.py
75
76
77
78
79
80
81
82
83
84
85
def get_module_path(self, module: nn.Module) -> str:
    """Get the registered path for a module.

    Args:
        module (nn.Module): The module to look up.

    Returns:
        str: The module's path, or "unknown" if not registered.

    """
    return self.module_paths.get(id(module), "unknown")