diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ed2eda..03a4376 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: - id: jupyter-clear-output name: jupyter-clear-output files: \.ipynb$ - stages: [commit] + stages: [pre-commit] language: system entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace diff --git a/lib/util/util/subject.py b/lib/util/util/subject.py index a9023e3..bc50974 100644 --- a/lib/util/util/subject.py +++ b/lib/util/util/subject.py @@ -302,7 +302,7 @@ def collect_acts( if "mlp_out_BTD" in include: layer_acts["mlp_out_BTD"] = self.mlps[layer].output.detach().save() if "attn_out_BTD" in include: - layer_acts["attn_out_BTD"] = self.attns[layer].output.detach().save() + layer_acts["attn_out_BTD"] = self.attns[layer].output[0].detach().save() if "attn_map_BQTT" in include: layer_acts["attn_map_BQTT"] = self.attns[layer].output[1].detach().save() if "neurons_BTI" in include: