Shape incompatibility when training SAE #11

@alw399

Description

Hello!

I was trying to use the scripts to train my own SAE model (examples/train_basic_sae.py or examples/train_multiple_sae_architectures.py --architecture fidelity), but ran into a few errors, possibly due to package inconsistencies. I have transformers==5.3.0 and nnsight==0.5.8, which seems fine according to the requirements file. The larger issue, however, is in get_esm_output_with_intervention, where submodule.input seems to have shape (batch, seq_len, hidden) while hidden_state_override is only (hidden,). I've made the change below, but I'm not entirely sure it's correct. Please let me know!

embd_to_patch = (
    # submodule.input[0][0]
    submodule.input
    if input_or_output == "input"
    else submodule.output
)
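For context, here is a minimal numpy sketch of the shape mismatch I mean. The names activations and hidden_state_override are stand-ins for submodule.input and the override vector; broadcasting the (hidden,)-shaped vector over the leading dimensions is my assumption about the intended behavior, not necessarily what the code should do:

```python
import numpy as np

# Stand-in shapes: submodule.input looks like (batch, seq_len, hidden),
# while hidden_state_override is just (hidden,).
batch, seq_len, hidden = 2, 4, 8
activations = np.zeros((batch, seq_len, hidden))  # stand-in for submodule.input
hidden_state_override = np.ones(hidden)           # stand-in for the override vector

# Broadcasting the (hidden,) vector over batch and seq_len makes the shapes line up.
patched = np.broadcast_to(hidden_state_override, activations.shape)
assert patched.shape == (batch, seq_len, hidden)
```

If the override is instead meant to replace only a single token position, the assignment would be per-position rather than a full broadcast, so please correct me if that's the case.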
