Skip to content

Commit e5db928

Browse files
coderabbitai[bot]marcuscollins
authored andcommitted
📝 CodeRabbit Chat: Implement requested code changes
1 parent adff4ae commit e5db928

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

src/sampleworks/core/samplers/edm.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,9 @@ def __post_init__(self) -> None:
137137
class AF3EDMSampler:
138138
"""EDM-style sampler from AF3-like models.
139139
140-
All constants are configurable via constructor for model-specific values.
141-
Default values match AF3 parameterization.
140+
Initialized with a single :class:`EDMSamplerConfig` object that holds all
141+
schedule hyperparameters and runtime options. Default values in the config
142+
match the AF3 parameterization.
142143
143144
This sampler implements the EDM (Karras et al.) style sampling
144145
approach as used in AlphaFold3 and related models, which is the Euler
@@ -154,6 +155,18 @@ class AF3EDMSampler:
154155
"""
155156

156157
def __init__(self, config: EDMSamplerConfig) -> None:
158+
"""Initialize the sampler with a configuration object.
159+
160+
Parameters
161+
----------
162+
config : EDMSamplerConfig
163+
Configuration object containing all schedule hyperparameters
164+
(``sigma_data``, ``s_max``, ``s_min``, ``p``, ``gamma_min``,
165+
``gamma_0``, ``noise_scale``, ``step_scale``) and runtime flags
166+
(``augmentation``, ``align_to_input``,
167+
``alignment_reverse_diffusion``, ``scale_guidance_to_diffusion``,
168+
``device``).
169+
"""
157170
self.config = config
158171

159172
def check_context(self, context: StepParams) -> None:
@@ -508,4 +521,4 @@ def step(
508521
denoised=x_hat_0_working_frame_t,
509522
loss=loss,
510523
log_proposal_correction=log_proposal_correction,
511-
)
524+
)

tests/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,8 @@ def create_sampler_from_type(
253253
if device is not None:
254254
config_kwargs["device"] = device
255255
config = config_cls(**config_kwargs)
256-
return create_component_from_info(info, config=config)
256+
cls = _import_from_path(info.module_path)
257+
return cls(config)
257258
return create_component_from_info(info, device=device, **extra_kwargs)
258259

259260

@@ -1067,4 +1068,4 @@ def perturbed_coords(
10671068
torch.manual_seed(42)
10681069
base = converging_mock_wrapper.target
10691070
perturbation = torch.randn_like(base) * 0.1 # ty: ignore[invalid-argument-type]
1070-
return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator]
1071+
return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator]

0 commit comments

Comments
 (0)