@@ -137,8 +137,9 @@ def __post_init__(self) -> None:
137137class AF3EDMSampler :
138138 """EDM-style sampler from AF3-like models.
139139
140- All constants are configurable via constructor for model-specific values.
141- Default values match AF3 parameterization.
140+ Initialized with a single :class:`EDMSamplerConfig` object that holds all
141+ schedule hyperparameters and runtime options. Default values in the config
142+ match the AF3 parameterization.
142143
143144 This sampler implements the EDM (Karras et al.) style sampling
144145 approach as used in AlphaFold3 and related models, which is the Euler
@@ -154,6 +155,18 @@ class AF3EDMSampler:
154155 """
155156
156157 def __init__ (self , config : EDMSamplerConfig ) -> None :
158+ """Initialize the sampler with a configuration object.
159+
160+ Parameters
161+ ----------
162+ config : EDMSamplerConfig
163+ Configuration object containing all schedule hyperparameters
164+ (``sigma_data``, ``s_max``, ``s_min``, ``p``, ``gamma_min``,
165+ ``gamma_0``, ``noise_scale``, ``step_scale``) and runtime flags
166+ (``augmentation``, ``align_to_input``,
167+ ``alignment_reverse_diffusion``, ``scale_guidance_to_diffusion``,
168+ ``device``).
169+ """
157170 self .config = config
158171
159172 def check_context (self , context : StepParams ) -> None :
@@ -508,4 +521,4 @@ def step(
508521 denoised = x_hat_0_working_frame_t ,
509522 loss = loss ,
510523 log_proposal_correction = log_proposal_correction ,
511- )
524+ )
0 commit comments