"""Self-contained subset of :mod:`circuit_sparsity.hook_utils` for inference builds.

The full module has no exotic dependencies, but mirroring the definitions here
keeps the trimmed :mod:`circuit_sparsity.inference.gpt` module hermetic and
easy to vendor.  The implementations below are copied with minor tweaks for
readability so that code written against :func:`hook_recorder`,
:func:`hook_namespace`, and :func:`torch_recompute_preserving_hook_context`
behaves identically in both the training and inference configurations.
"""

from __future__ import annotations

import re
from contextlib import contextmanager
from functools import partial

import torch
import torch.utils.checkpoint


class HookContext:
    """State container used by the hook helpers."""

    def __init__(self) -> None:
        self._reset()
        # The intervention transformer deliberately survives ``_reset``: it is
        # managed separately by ``hook_intervention_transform``.
        self.curintervtransformer = lambda x: x

    def _reset(self) -> None:
        self.curcontext = None
        self.curname = ""
        self.curregex = None
        self.curinterventions = None
        self.save_grads = None

    def _get_interventions(self):
        return self.curintervtransformer(
            self.curinterventions if self.curinterventions is not None else {}
        )

    @contextmanager
    def hook_recorder(self, regex: str = ".*", interventions=None, save_grads: bool = False):
        """Record tensors that pass through hooks matching ``regex``."""

        assert self.curcontext is None, "reentrancy not allowed!"

        try:
            self.curcontext = {}
            self.curregex = re.compile(regex)
            self.curname = ""
            self.curinterventions = interventions
            self.save_grads = save_grads

            yield self.curcontext
        finally:
            # Reset both this context and whichever context is currently
            # installed; they can differ if a checkpoint recompute swapped
            # contexts while recording.
            self._reset()
            get_context()._reset()

    @contextmanager
    def hook_intervention_transform(self, intervention_transformer):
        """Compose ``intervention_transformer`` onto the current transformer.

        The transformer maps the interventions dict to a new one; when uses are
        nested, the most recently installed transformer is applied last.
        """

        oldintervention_transformer = self.curintervtransformer

        def compose(f, g):
            return lambda x: f(g(x))

        self.curintervtransformer = compose(
            intervention_transformer,
            self.curintervtransformer,
        )

        try:
            yield
        finally:
            self.curintervtransformer = oldintervention_transformer

    @contextmanager
    def hook_namespace(self, name: str):
        """Temporarily push ``name`` onto the hook namespace stack."""

        oldname = self.curname
        self.curname = self.curname + name + "."

        try:
            yield
        finally:
            self.curname = oldname

    def hook_save(self, name: str, tensor: torch.Tensor) -> torch.Tensor:
        """Optionally record ``tensor`` under the current namespace."""

        # Bind the fully qualified key now: ``curname`` is mutated by
        # ``hook_namespace``, so it must not be read lazily inside the backward
        # hook below (backward typically runs after the namespace has exited).
        key = self.curname + name

        curinterventions = self._get_interventions()
        if curinterventions is not None and key in curinterventions:
            tensor = curinterventions[key](tensor)

        if self.curcontext is not None and self.curregex.match(key):
            self.curcontext[key] = tensor

        if self.curcontext is not None and self.save_grads and tensor.requires_grad:
            grad_key = key + ".grad"

            class _Grad(torch.autograd.Function):
                """Identity in the forward pass; records the gradient going back."""

                @staticmethod
                def forward(ctx, input_tensor):
                    return input_tensor

                @staticmethod
                def backward(ctx, grad_output):
                    self.curcontext[grad_key] = grad_output
                    return grad_output

            if self.curregex.match(grad_key):
                tensor = _Grad.apply(tensor)

        return tensor
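
# An interventions sketch (illustrative; the key name is hypothetical): each
# entry of ``interventions`` maps a saved key to a function that replaces the
# tensor at that point in the forward pass, e.g. zero-ablation:
#
#     with hook_recorder(interventions={"block0.pre": torch.zeros_like}):
#         out = model(inputs)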


def set_context(new_context: HookContext) -> None:
    global context
    context = new_context


def get_context() -> HookContext:
    return context


def torch_recompute_preserving_hook_context(f, *xs, use_reentrant=None):
    """Wrapper around :func:`torch.utils.checkpoint` that propagates hooks."""

    oldcontext = get_context()
    curcontext = HookContext()
    curcontext.curcontext = (
        dict(oldcontext.curcontext) if oldcontext.curcontext is not None else None
    )
    curcontext.curregex = oldcontext.curregex
    curcontext.curname = oldcontext.curname
    curcontext.curinterventions = (
        dict(oldcontext.curinterventions) if oldcontext.curinterventions is not None else None
    )
    # Carry the intervention transformer over as well, so interventions are
    # applied identically inside and outside the checkpointed region.
    curcontext.curintervtransformer = oldcontext.curintervtransformer
    curcontext.save_grads = oldcontext.save_grads

    # Under checkpointing ``_f`` runs twice: once in the forward pass and once
    # when backward triggers the recompute.  Only the first run merges its
    # recordings back into the caller's context.
    is_recompute = False

    def _f(curcontext: HookContext, *xs):
        initcontext = get_context()
        nonlocal is_recompute

        set_context(curcontext)
        try:
            res = f(*xs)

            if not is_recompute and oldcontext.curcontext is not None:
                oldcontext.curcontext |= curcontext.curcontext
        finally:
            set_context(initcontext)
            is_recompute = True
        return res

    res = torch.utils.checkpoint.checkpoint(
        partial(_f, curcontext), *xs, use_reentrant=use_reentrant
    )

    return res
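
# A usage sketch (illustrative; ``block`` and the key name are hypothetical):
# recordings made inside the checkpointed call stay visible to an enclosing
# ``hook_recorder``:
#
#     def block(x):
#         return hook_save("block_out", torch.relu(x))
#
#     with hook_recorder() as recorded:
#         y = torch_recompute_preserving_hook_context(block, x, use_reentrant=False)
#     recorded["block_out"]  # populated despite the checkpoint boundary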


# Process-global hook context.  ``torch_recompute_preserving_hook_context``
# temporarily installs a private context around checkpointed calls.
context = HookContext()


def hook_recorder(*a, **k):
    return get_context().hook_recorder(*a, **k)


def hook_namespace(*a, **k):
    return get_context().hook_namespace(*a, **k)


def hook_save(*a, **k):
    return get_context().hook_save(*a, **k)


def hook_intervention_transform(*a, **k):
    return get_context().hook_intervention_transform(*a, **k)
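

# Minimal smoke test.  ``_DemoBlock`` and the key names are illustrative only;
# they are not part of the vendored API.
if __name__ == "__main__":

    class _DemoBlock(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            with hook_namespace("block0"):
                x = hook_save("pre", x)
                x = self.linear(x)
                x = hook_save("post", x)
            return x

    model = _DemoBlock()
    inputs = torch.randn(2, 4, requires_grad=True)

    with hook_recorder(regex=r"block0\..*", save_grads=True) as recorded:
        model(inputs).sum().backward()

    # Expect "block0.pre" / "block0.post" plus their ".grad" counterparts.
    print(sorted(recorded))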