84 changes: 69 additions & 15 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -1029,6 +1029,34 @@ def partial_rowwise_lamb() -> Dict[str, Any]:

def adam() -> Dict[str, Any]:
split_precomputation = """
+    // Define the optimizer state (for use with optimizer offloading)
+    struct OptimizerState {
+        // momentum1 is an array of D values beginning at the optimizer state address
+        DEVICE_INLINE momentum1_ph_t* momentum1_ptr() {
+            return reinterpret_cast<momentum1_ph_t*>(this);
+        }
+
+        // momentum2 is an array of D values laid out immediately after the momentum1 array
+        DEVICE_INLINE momentum2_ph_t* momentum2_ptr(const int32_t D) {
+            // Advance D momentum1 elements past the base address, then
+            // reinterpret the address as a momentum2_ph_t pointer
+            auto addr = reinterpret_cast<uintptr_t>(momentum1_ptr() + D);
+            return reinterpret_cast<momentum2_ph_t*>(addr);
+        }
+    };
+
+    // Fetch the pointer to the optimizer state along the cache row
+    [[maybe_unused]] auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
+
+    // Fetch the pointers to the momentum1 and momentum2 buffers
+    auto* momentum1_start = enable_optimizer_offloading ?
+        (optimizer->momentum1_ptr()) :
+        (&momentum1[idx * D]);
+    auto* momentum2_start = enable_optimizer_offloading ?
+        (optimizer->momentum2_ptr(D)) :
+        (&momentum2[idx * D]);

at::acc_type<cache_t, true>* __restrict__ row_counter;
at::acc_type<cache_t, true> _row_counter = iter;
if (use_rowwise_bias_correction) {
@@ -1052,20 +1080,37 @@ def adam() -> Dict[str, Any]:
"""

split_weight_update = """
-    Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
-    m_t.mul_(beta1);
-    m_t.fma_(grad, 1.0 - beta1);
-    m_t.store(&momentum1[idx * D + d]);
-
-    Vec4T<cache_t> v_t(&momentum2[idx * D + d]);
-    v_t.mul_(beta2);
-
-    grad.acc.x *= grad.acc.x;
-    grad.acc.y *= grad.acc.y;
-    grad.acc.z *= grad.acc.z;
-    grad.acc.w *= grad.acc.w;
-    v_t.fma_(grad, 1.0 - beta2);
-    v_t.store(&momentum2[idx * D + d]);
+    // Determine the access pointers for momentum1 and momentum2
+    auto* momentum1_ptr = momentum1_start + d;
+    auto* momentum2_ptr = momentum2_start + d;
+
+    Vec4T<momentum1_ph_t> m_t;
+    Vec4T<momentum2_ph_t> v_t;
+
+    if (enable_optimizer_offloading) {
+        m_t = vec4_load_unaligned(momentum1_ptr);
+        m_t.mul_(beta1);
+        m_t.fma_(grad, 1.0 - beta1);
+        vec4_store_unaligned(m_t, momentum1_ptr);
+
+        v_t = vec4_load_unaligned(momentum2_ptr);
+        v_t.mul_(beta2);
+        grad.sq_();
+        v_t.fma_(grad, 1.0 - beta2);
+        vec4_store_unaligned(v_t, momentum2_ptr);
+    } else {
+        m_t = Vec4T<momentum1_ph_t>(momentum1_ptr);
+        m_t.mul_(beta1);
+        m_t.fma_(grad, 1.0 - beta1);
+        m_t.store(momentum1_ptr);
+
+        v_t = Vec4T<momentum2_ph_t>(momentum2_ptr);
+        v_t.mul_(beta2);
+        grad.sq_();
+        v_t.fma_(grad, 1.0 - beta2);
+        v_t.store(momentum2_ptr);
+    }

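// Bias-corrected Adam step per lane, with t = _row_counter:
//   m_hat = m_t / (1 - powf(beta1, t)),  v_hat = v_t / (1 - powf(beta2, t))
//   w <- w - learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w)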
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.x);
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.y);
@@ -1079,8 +1124,16 @@ def adam() -> Dict[str, Any]:
"is_experimental_optimizer": True,
"args": OptimizerArgsSet.create(
[
-    OptimItem(ArgType.TENSOR, "momentum1"),
-    OptimItem(ArgType.TENSOR, "momentum2"),
+    OptimItem(
+        ArgType.PLACEHOLDER_TENSOR,
+        "momentum1",
+        ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
+    ),
+    OptimItem(
+        ArgType.PLACEHOLDER_TENSOR,
+        "momentum2",
+        ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
+    ),
OptimItem(ArgType.TENSOR, "learning_rate_tensor"),
OptimItem(ArgType.FLOAT, "eps"),
OptimItem(ArgType.FLOAT, "beta1"),
@@ -1102,7 +1155,7 @@ def adam() -> Dict[str, Any]:
"has_gpu_support": True,
"has_vbe_support": True,
"has_global_weight_decay_support": False,
"has_ssd_support": False,
"has_ssd_support": True,
}

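The diff above makes two related changes to adam(): momentum1 and momentum2 become placeholder tensors (momentum1_ph_t / momentum2_ph_t can resolve to float or bfloat16), and a packed OptimizerState is introduced so that, when enable_optimizer_offloading is set, both momentum buffers live inline on the cache row rather than in separate tensors. Below is a minimal host-side sketch of the packed layout that OptimizerState encodes; it is illustrative code, not part of the PR, and uses plain float as a stand-in for the placeholder types.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    using momentum1_t = float;  // stand-in for momentum1_ph_t
    using momentum2_t = float;  // stand-in for momentum2_ph_t

    struct OptimizerState {
      momentum1_t* momentum1_ptr() {
        // momentum1 begins at the optimizer state address itself
        return reinterpret_cast<momentum1_t*>(this);
      }
      momentum2_t* momentum2_ptr(int32_t D) {
        // momentum2 begins D elements past the start of momentum1
        auto addr = reinterpret_cast<uintptr_t>(momentum1_ptr() + D);
        return reinterpret_cast<momentum2_t*>(addr);
      }
    };

    int main() {
      constexpr int32_t D = 8;
      std::vector<float> row(2 * D);  // [ m1[0..D) | m2[0..D) ]
      auto* state = reinterpret_cast<OptimizerState*>(row.data());
      assert(state->momentum1_ptr() == row.data());
      assert(state->momentum2_ptr(D) == row.data() + D);
      return 0;
    }

Note that momentum2 starts exactly D elements past momentum1, so its address is not guaranteed to be 16-byte aligned; that is presumably why the offloaded path uses vec4_load_unaligned / vec4_store_unaligned rather than the aligned Vec4T constructor and store().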

@@ -1119,6 +1172,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
grad->w * grad->w;
"""
)
+
split_precomputation += """

// Define the optimizer state (for use with optimizer offloading)
21 changes: 21 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/vec4.cuh
@@ -232,6 +232,13 @@ struct Vec4T<float> : public Vec4BaseT<float> {
acc.w *= scale;
}

+  DEVICE_INLINE void sq_() {
+    acc.x *= acc.x;
+    acc.y *= acc.y;
+    acc.z *= acc.z;
+    acc.w *= acc.w;
+  }

// this <- this element-wise mul a
DEVICE_INLINE void element_wise_mul_(const Vec4T<float>& a) {
acc.x *= a.acc.x;
@@ -426,6 +433,13 @@ struct Vec4T<at::Half> : public Vec4BaseT<at::Half> {
acc.z *= scale;
acc.w *= scale;
}

+  DEVICE_INLINE void sq_() {
+    acc.x *= acc.x;
+    acc.y *= acc.y;
+    acc.z *= acc.z;
+    acc.w *= acc.w;
+  }
};

////////////////////////////////////////////////////////////////////////////////
@@ -593,6 +607,13 @@ struct Vec4T<at::BFloat16> : public Vec4BaseT<at::BFloat16> {
acc.z *= scale;
acc.w *= scale;
}

+  DEVICE_INLINE void sq_() {
+    acc.x *= acc.x;
+    acc.y *= acc.y;
+    acc.z *= acc.z;
+    acc.w *= acc.w;
+  }
};

////////////////////////////////////////////////////////////////////////////////
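Each Vec4T specialization above gains the same sq_() method: an in-place element-wise square, which lets the Adam kernel write grad.sq_() instead of four explicit grad.acc.* *= grad.acc.* statements. Below is a small stand-alone sketch of the semantics and of the second-moment update it supports; Vec4 here is an illustrative stand-in, not the real Vec4T.

    #include <cstdio>

    // Simplified stand-in for Vec4T<float> (the real one is in vec4.cuh)
    struct Vec4 {
      float x, y, z, w;
      void sq_() { x *= x; y *= y; z *= z; w *= w; }   // element-wise square, in place
      void mul_(float s) { x *= s; y *= s; z *= s; w *= s; }
      void fma_(const Vec4& a, float s) {              // this += a * s, element-wise
        x += a.x * s; y += a.y * s; z += a.z * s; w += a.w * s;
      }
    };

    int main() {
      Vec4 grad{1.0f, -2.0f, 3.0f, -4.0f};
      Vec4 v{0.0f, 0.0f, 0.0f, 0.0f};
      const float beta2 = 0.999f;

      // Second-moment update from the Adam kernel, using sq_():
      v.mul_(beta2);
      grad.sq_();                  // grad now holds grad^2 per lane
      v.fma_(grad, 1.0f - beta2);  // v = beta2 * v + (1 - beta2) * grad^2

      std::printf("v = (%f, %f, %f, %f)\n", v.x, v.y, v.z, v.w);
      return 0;
    }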