diff --git a/.gitignore b/.gitignore index 476214a5..8cb8ca35 100755 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,7 @@ build .vscode/** # crash dumps core.* +*.egg-info +*.mudmp +*.whl +*.so diff --git a/README.en.md b/README.en.md index 73c8e22b..aa36a725 100755 --- a/README.en.md +++ b/README.en.md @@ -10,6 +10,7 @@ TensorFlow MUSA Extension is a high-performance TensorFlow plugin specifically d - **Seamless Integration**: Fully compatible with TensorFlow ecosystem without requiring code modifications - **Device Management**: Complete MUSA device registration, memory management, and stream processing support - **Kernel Debugging Support**: Built-in kernel execution time statistics for performance analysis +- **Python Package Support**: Provides `tensorflow_musa` Python package with pip installation and optimizer interface ## Quick Start @@ -18,12 +19,20 @@ TensorFlow MUSA Extension is a high-performance TensorFlow plugin specifically d ``` tensorflow_musa_extension/ ├── CMakeLists.txt # CMake build configuration -├── build.sh # Build script +├── build.sh # Build script (supports release/debug/wheel) +├── setup.py # Python package build configuration ├── .clang-format # Code formatting configuration ├── .pre-commit-config.yaml # pre-commit hook configuration -├── .gitlab-ci.yml # CI/CD configuration +├── .github/ # CI/CD configuration +├── python/ # Python package source directory (pip name: tensorflow_musa) +│ ├── __init__.py # Package entry, auto-loads plugin +│ ├── _loader.py # Plugin loading utilities +│ ├── _patch.py # tf.keras.optimizers.Adam monkey patch +│ └── optimizer/ # Optimizer module +│ ├── __init__.py +│ └── adam.py # MUSA Adam optimizer (supports sparse update) ├── musa_ext/ # Core source directory -│ ├── kernels/ # MUSA kernel implementations +│ ├── kernels/ # MUSA kernel implementations (.mu files) │ ├── mu/ # MUSA device and optimizer implementations │ └── utils/ # Utility functions └── test/ # Test cases @@ -45,61 +54,93 @@ 
tensorflow_musa_extension/ - Default installation path: `/usr/local/musa` - **Python Dependencies**: - Python: >= 3.7 - - TensorFlow: == 2.6.1 - - protobuf: == 3.20.3 + - TensorFlow: == 2.6.1 (required version) - NumPy: >= 1.19.0 - - prettytable: >= 3.0.0 - **Development Tools**: - pre-commit >= 3.0.0 - pytest >= 6.0.0 -### Installation +### Installation Methods + +#### Method 1: Install WHL Package (Recommended) ```bash # Clone the repository git clone cd tensorflow_musa_extension -# Build the plugin -./build.sh +# Ensure TensorFlow 2.6.1 is installed +pip install tensorflow==2.6.1 + +# Build WHL package (one-click build) +./build.sh wheel + +# Install WHL package +pip install dist/tensorflow_musa-0.1.0-py3-none-any.whl --no-deps + +# Install WHL packages after rebuilding +pip install dist/tensorflow_musa-0.1.0-py3-none-any.whl --no-deps --force-reinstall +``` + +#### Method 2: Development Mode + +```bash +# Clone the repository +git clone +cd tensorflow_musa_extension + +# Build plugin +./build.sh release -# Load the plugin in Python +# Load plugin in Python for testing import tensorflow as tf tf.load_library("./build/libmusa_plugin.so") ``` ## Build Guide -### 1. Build Type +### 1. Build Modes -Both Release and Debug modes are supported: +Three build modes are supported: | Mode | Command | Description | |------|---------|-------------| -| **Release** | `./build.sh` or `./build.sh release` | Optimized for performance, no debug overhead | +| **Release** | `./build.sh` or `./build.sh release` | Optimized performance, generates `build/libmusa_plugin.so` | | **Debug** | `./build.sh debug` | Enables `MUSA_KERNEL_DEBUG` and kernel timing macros | +| **Wheel** | `./build.sh wheel` | One-click WHL package build, generates `dist/tensorflow_musa-*.whl` | ### 2. 
Compilation Process -Execute the automated build script: - ```bash -# Release (default) +# Release (default) - build plugin only ./build.sh -# Release (explicit) -./build.sh release - # Debug (timing instrumentation) ./build.sh debug + +# Wheel (build release package) +./build.sh wheel ``` -The build script automatically completes the following steps: -- Configures the CMake project +The build script automatically: +- Checks TensorFlow version (must be 2.6.1) +- Configures CMake project - Compiles MUSA kernels and host code -- Generates the dynamic library `libmusa_plugin.so` +- Generates `libmusa_plugin.so` or WHL package + +### 3. WHL Package Notes + +WHL package build features: +- **No auto-download TensorFlow**: Prevents pip from downloading incompatible versions +- **Version check**: Automatically checks TensorFlow version is 2.6.1 before build +- **Package name mapping**: Source directory is `python/`, but pip package name is `tensorflow_musa` + +After installation: +```python +import tensorflow_musa as tf_musa # Package name remains tensorflow_musa +``` -### 3. Debugging and Diagnostics +### 4. 
Debugging and Diagnostics For detailed debugging guide, see [docs/DEBUG_GUIDE.md](docs/DEBUG_GUIDE.md), including: @@ -186,6 +227,122 @@ Current version supports the following core operators: - **Data Manipulation**: Reshape, Concat, Gather, StridedSlice, ExpandDims - **Normalization**: LayerNorm, FusedBatchNorm - **Special Operators**: TensorInteraction, BiasAdd, Assign +- **Optimizers**: ResourceApplyAdam, MusaResourceSparseApplyAdam (supports embedding sparse update) + +## Usage Examples + +### Basic Usage + +After installing the `tensorflow_musa` package, the plugin is automatically loaded on import: + +```python +import tensorflow_musa as tf_musa + +# Check version +print(f"TensorFlow MUSA version: {tf_musa.__version__}") + +# View available MUSA devices +devices = tf_musa.get_musa_devices() +print(f"Available MUSA devices: {devices}") +``` + +### Auto Patch tf.keras.optimizers.Adam (Recommended) + +After importing `tensorflow_musa`, `tf.keras.optimizers.Adam` is automatically patched to use MUSA fused kernels. 
No code changes needed: + +```python +import tensorflow as tf +import tensorflow_musa as tf_musa # Auto patches Adam + +# Create model +model = tf.keras.Sequential([ + tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)), + tf.keras.layers.Dense(10, activation='softmax') +]) + +# Use standard tf.keras.optimizers.Adam (auto patched) +optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) + +# Compile model +model.compile( + optimizer=optimizer, + loss='sparse_categorical_crossentropy', + metrics=['accuracy'] +) + +# Embedding sparse gradients automatically use MusaResourceSparseApplyAdam kernel +``` + +### Explicitly Use MUSA Adam Optimizer + +If you want to explicitly specify MUSA optimizer: + +```python +import tensorflow as tf +import tensorflow_musa as tf_musa + +# Create model +model = tf.keras.Sequential([ + tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)), + tf.keras.layers.Dense(10, activation='softmax') +]) + +# Explicitly use MUSA fused Adam optimizer +optimizer = tf_musa.optimizer.Adam( + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7 +) + +# Compile model +model.compile( + optimizer=optimizer, + loss='sparse_categorical_crossentropy', + metrics=['accuracy'] +) +``` + +### Device Management + +```python +import tensorflow as tf +import tensorflow_musa as tf_musa + +# Set specific MUSA device +with tf.device('/device:MUSA:0'): + # Create tensors and compute on MUSA device + a = tf.constant([[1.0, 2.0], [3.0, 4.0]]) + b = tf.constant([[5.0, 6.0], [7.0, 8.0]]) + c = tf.matmul(a, b) + print(c) +``` + +### Embedding Sparse Update Example + +MUSA Adam optimizer supports sparse gradient updates for embedding scenarios: + +```python +import tensorflow as tf +import tensorflow_musa as tf_musa + +# Create embedding variable +vocab_size = 10000 +embedding_dim = 128 +with tf.device('/device:MUSA:0'): + embedding = tf.Variable(tf.zeros([vocab_size, embedding_dim])) + +# Use patched Adam +optimizer = 
tf.keras.optimizers.Adam(learning_rate=0.001) + +# Simulate embedding lookup sparse gradient +indices = tf.constant([0, 5, 10, 15]) # Word IDs in batch +values = tf.random.normal([4, embedding_dim]) # Corresponding gradients +sparse_grad = tf.IndexedSlices(values, indices) + +# Apply sparse gradient update (auto uses MusaResourceSparseApplyAdam kernel) +optimizer.apply_gradients([(sparse_grad, embedding)]) +``` ## Contribution Guidelines diff --git a/README.md b/README.md index 1fa57653..47fed079 100755 --- a/README.md +++ b/README.md @@ -4,12 +4,13 @@ TensorFlow MUSA Extension 是一个高性能的 TensorFlow 插件,专为摩尔 ## 特性 -- **完整的算子支持**:涵盖深度学习训练和推理所需的核心算子 +- **完整的算子支持**:涵盖深度学习训练和推理所需的核心算子 - **高性能优化**:针对 MUSA 架构进行深度优化,包括内存访问模式和计算效率 - **自动图优化**:支持 Layout 自动转换、算子融合和自动混合精度(AMP) - **无缝集成**:与 TensorFlow 生态系统完全兼容,无需修改现有代码 - **设备管理**:完整的 MUSA 设备注册、内存管理和流式处理支持 - **Kernel 调试支持**:内置 Kernel 执行时间统计功能,便于性能分析 +- **Python 包支持**:提供 `tensorflow_musa` Python 包,支持 pip 安装和优化器接口 ## 快速开始 @@ -18,12 +19,20 @@ TensorFlow MUSA Extension 是一个高性能的 TensorFlow 插件,专为摩尔 ``` tensorflow_musa_extension/ ├── CMakeLists.txt # CMake 构建配置 -├── build.sh # 构建脚本 +├── build.sh # 构建脚本(支持 release/debug/wheel) +├── setup.py # Python 包构建配置 ├── .clang-format # 代码格式化配置 ├── .pre-commit-config.yaml # pre-commit 钩子配置 -├── .gitlab-ci.yml # CI/CD 配置 +├── .github/ # CI/CD 配置 +├── python/ # Python 包源码目录(pip 安装名为 tensorflow_musa) +│ ├── __init__.py # 包入口,自动加载插件 +│ ├── _loader.py # 插件加载工具 +│ ├── _patch.py # tf.keras.optimizers.Adam monkey patch +│ └── optimizer/ # 优化器模块 +│ ├── __init__.py +│ └── adam.py # MUSA Adam 优化器(支持 sparse update) ├── musa_ext/ # 核心源码目录 -│ ├── kernels/ # MUSA 内核实现 +│ ├── kernels/ # MUSA 内核实现(.mu 文件) │ ├── mu/ # MUSA 设备和优化器实现 │ └── utils/ # 工具函数 └── test/ # 测试用例 @@ -45,25 +54,45 @@ tensorflow_musa_extension/ - 默认安装路径: `/usr/local/musa` - **Python 依赖** - Python: >= 3.7 - - TensorFlow: == 2.6.1 - - protobuf: == 3.20.3 + - TensorFlow: == 2.6.1(必须使用此版本) - NumPy: >= 1.19.0 - - pettytable: >= 3.0.0 - **开发工具**: - 
pre-commit >= 3.0.0 - pytest >= 6.0.0 -### 安装 +### 安装方式 + +#### 方式一:安装 WHL 包(推荐) ```bash # 克隆仓库 git clone cd tensorflow_musa_extension -# 构建插件 -./build.sh +# 确保 TensorFlow 2.6.1 已安装 +pip install tensorflow==2.6.1 + +# 构建 WHL 包(一键构建) +./build.sh wheel + +# 安装 WHL 包 +pip install dist/tensorflow_musa-0.1.0-py3-none-any.whl --no-deps + +# 重新构建后安装 WHL 包 +pip install dist/tensorflow_musa-0.1.0-py3-none-any.whl --no-deps --force-reinstall +``` + +#### 方式二:开发模式 + +```bash +# 克隆仓库 +git clone +cd tensorflow_musa_extension + +# 构建 plugin +./build.sh release -# 在 Python 中加载插件 +# 在 Python 中手动加载插件进行测试 import tensorflow as tf tf.load_library("./build/libmusa_plugin.so") ``` @@ -72,34 +101,46 @@ tf.load_library("./build/libmusa_plugin.so") ### 1. 编译模式 -支持 Release 与 Debug 两种模式: +支持三种构建模式: | 模式 | 命令 | 说明 | |------|------|------| -| **Release** | `./build.sh` 或 `./build.sh release` | 优化性能,无调试开销 | +| **Release** | `./build.sh` 或 `./build.sh release` | 优化性能,生成 `build/libmusa_plugin.so` | | **Debug** | `./build.sh debug` | 开启 `MUSA_KERNEL_DEBUG`,启用 kernel timing 宏 | +| **Wheel** | `./build.sh wheel` | 一键构建 WHL 包,生成 `dist/tensorflow_musa-*.whl` | ### 2. 编译流程 -执行自动化构建脚本: - ```bash -# Release(默认) +# Release(默认)- 仅构建 plugin ./build.sh -# Release(显式) -./build.sh release - # Debug(计时调试) ./build.sh debug + +# Wheel(构建发布包) +./build.sh wheel ``` 构建脚本将自动完成以下步骤: +- 检查 TensorFlow 版本(必须为 2.6.1) - 配置 CMake 项目 - 编译 MUSA 内核和主机代码 -- 生成动态链接库 `libmusa_plugin.so` +- 生成动态链接库 `libmusa_plugin.so` 或 WHL 包 + +### 3. WHL 包说明 + +WHL 包构建特点: +- **不自动下载 TensorFlow**:避免 pip 自动下载不兼容版本 +- **版本检查**:构建前自动检查环境中 TensorFlow 版本是否为 2.6.1 +- **包名映射**:源码目录为 `python/`,安装后包名为 `tensorflow_musa` + +安装后使用: +```python +import tensorflow_musa as tf_musa # 包名仍然是 tensorflow_musa +``` -### 3. 调试与诊断 +### 4. 
调试与诊断 详细的调试指南请参阅 [docs/DEBUG_GUIDE.md](docs/DEBUG_GUIDE.md),包含: @@ -186,9 +227,126 @@ python test_runner.py --quiet - **数据操作**:Reshape, Concat, Gather, StridedSlice, ExpandDims - **归一化**:LayerNorm, FusedBatchNorm - **特殊算子**:TensorInteraction, BiasAdd, Assign +- **优化器**:ResourceApplyAdam, MusaResourceSparseApplyAdam(支持 embedding 稀疏更新) ## 使用示例 +### 基本用法 + +安装 `tensorflow_musa` 包后,导入时插件会自动加载: + +```python +import tensorflow_musa as tf_musa + +# 查看版本 +print(f"TensorFlow MUSA 版本: {tf_musa.__version__}") + +# 查看可用的 MUSA 设备 +devices = tf_musa.get_musa_devices() +print(f"可用 MUSA 设备: {devices}") +``` + +### 自动 Patch tf.keras.optimizers.Adam(推荐) + +导入 `tensorflow_musa` 后,`tf.keras.optimizers.Adam` 会自动被 patch, +使用 MUSA 融合内核,无需修改现有代码: + +```python +import tensorflow as tf +import tensorflow_musa as tf_musa # 自动 patch Adam + +# 创建模型 +model = tf.keras.Sequential([ + tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)), + tf.keras.layers.Dense(10, activation='softmax') +]) + +# 使用标准的 tf.keras.optimizers.Adam(已被自动 patch) +optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) + +# 编译模型 +model.compile( + optimizer=optimizer, + loss='sparse_categorical_crossentropy', + metrics=['accuracy'] +) + +# 训练时 embedding 稀疏梯度会自动使用 MusaResourceSparseApplyAdam kernel +``` + +### 显式使用 MUSA Adam 优化器 + +如果需要显式指定 MUSA 优化器: + +```python +import tensorflow as tf +import tensorflow_musa as tf_musa + +# 创建模型 +model = tf.keras.Sequential([ + tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)), + tf.keras.layers.Dense(10, activation='softmax') +]) + +# 显式使用 MUSA 融合 Adam 优化器 +optimizer = tf_musa.optimizer.Adam( + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7 +) + +# 编译模型 +model.compile( + optimizer=optimizer, + loss='sparse_categorical_crossentropy', + metrics=['accuracy'] +) +``` + +### 设备管理 + +```python +import tensorflow as tf +import tensorflow_musa as tf_musa + +# 设置使用特定 MUSA 设备 +with tf.device('/device:MUSA:0'): + # 在 MUSA 设备上创建张量和计算 + a = 
tf.constant([[1.0, 2.0], [3.0, 4.0]]) + b = tf.constant([[5.0, 6.0], [7.0, 8.0]]) + c = tf.matmul(a, b) + print(c) +``` + +### Embedding 稀疏更新示例 + +MUSA Adam 优化器支持稀疏梯度更新,适用于 embedding 场景: + +```python +import tensorflow as tf +import tensorflow_musa as tf_musa + +# 创建 embedding 变量 +vocab_size = 10000 +embedding_dim = 128 +with tf.device('/device:MUSA:0'): + embedding = tf.Variable(tf.zeros([vocab_size, embedding_dim])) + +# 使用 patch 后的 Adam +optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) + +# 模拟 embedding lookup 的稀疏梯度 +indices = tf.constant([0, 5, 10, 15]) # batch 中涉及的词 ID +values = tf.random.normal([4, embedding_dim]) # 对应的梯度 +sparse_grad = tf.IndexedSlices(values, indices) + +# 应用稀疏梯度更新(自动使用 MusaResourceSparseApplyAdam kernel) +optimizer.apply_gradients([(sparse_grad, embedding)]) +``` + +### 更多示例 + 详细使用示例见: [![MUSA Playground](https://img.shields.io/badge/Gitee-TensorFlow_MUSA_Playground-C71D23?style=for-the-badge&logo=gitee&logoColor=white)](https://gitee.com/mthreadsacademy/tensorflow_musa_playground) diff --git a/build.sh b/build.sh index b0a62f00..8e9bd1f9 100755 --- a/build.sh +++ b/build.sh @@ -4,14 +4,35 @@ set -e # ============================================================================ # MUSA Plugin Build Script # Usage: -# ./build.sh [release|debug] +# ./build.sh [release|debug|wheel] # # Examples: -# ./build.sh # Default: release mode +# ./build.sh # Default: release mode (build .so only) # ./build.sh release # Release mode (optimized) # ./build.sh debug # Debug mode (kernel timing enabled) +# ./build.sh wheel # Build wheel package directly (recommended for distribution) # ============================================================================ +# Required TensorFlow version +REQUIRED_TF_VERSION="2.6.1" + +# Function to check TensorFlow version +check_tf_version() { + echo "Checking TensorFlow version..." 
+ python3 -c " +import tensorflow as tf +version = tf.__version__ +required = '${REQUIRED_TF_VERSION}' +if version != required: + print(f'ERROR: TensorFlow version mismatch!') + print(f' Required: {required}') + print(f' Installed: {version}') + print(f' Please install: pip install tensorflow=={required}') + exit(1) +print(f'TensorFlow {version} found - OK') +" || exit 1 +} + # Parse build type from command line argument BUILD_TYPE="${1:-release}" BUILD_TYPE=$(echo "$BUILD_TYPE" | tr '[:upper:]' '[:lower:]') @@ -40,17 +61,55 @@ case "$BUILD_TYPE" in echo " • Use env vars MUSA_TIMING_KERNEL_* to control output" echo "" ;; + wheel) + echo "==========================================" + echo "Building tensorflow_musa Wheel Package" + echo "==========================================" + echo "" + check_tf_version + echo "" + echo "Building wheel package..." + echo "" + + # Clean previous wheel builds + rm -rf build/lib build/bdist.* dist/*.whl 2>/dev/null || true + + # Build wheel using setup.py (no isolation to use existing TF) + python3 setup.py bdist_wheel + + # Find and display the built wheel + WHEEL_FILE=$(ls dist/*.whl 2>/dev/null | head -1) + if [ -n "$WHEEL_FILE" ]; then + echo "" + echo "[SUCCESS] Wheel package built successfully!" 
+ ls -lh "$WHEEL_FILE" + echo "" + echo "==========================================" + echo "Install with:" + echo " pip install $WHEEL_FILE --no-deps" + echo "==========================================" + else + echo "" + echo "[FAIL] Wheel package not found in dist/" + exit 1 + fi + exit 0 + ;; *) echo "Error: Unknown build type '$BUILD_TYPE'" - echo "Usage: ./build.sh [release|debug]" + echo "Usage: ./build.sh [release|debug|wheel]" echo "" echo "Options:" echo " release - Optimized release build (default)" echo " debug - Enable MUSA kernel debug/timing macros" + echo " wheel - Build wheel package for distribution" exit 1 ;; esac +# Check TensorFlow version before building .so +check_tf_version + # Clean previous build if needed rm -rf build @@ -88,4 +147,7 @@ echo "Build Complete!" echo "==========================================" echo "Build Type: $BUILD_TYPE" echo "Plugin: $(pwd)/libmusa_plugin.so" +echo "" +echo "To build wheel package:" +echo " ./build.sh wheel" echo "==========================================" diff --git a/musa_ext/kernels/training/musa_apply_sparse_adam_kernel.mu b/musa_ext/kernels/training/musa_apply_sparse_adam_kernel.mu new file mode 100644 index 00000000..cb6d5efe --- /dev/null +++ b/musa_ext/kernels/training/musa_apply_sparse_adam_kernel.mu @@ -0,0 +1,141 @@ +#include +#include +#include +#include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-pragmas" +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/types.h" +#pragma GCC diagnostic pop + +namespace tensorflow { +namespace musa { + +namespace { +__device__ __forceinline__ float LoadFloat(const float* p) { return *p; } +__device__ __forceinline__ void StoreFloat(float* p, float v) { *p = v; } + +__device__ __forceinline__ float LoadFloat(const Eigen::half* p) { + const __half* h_ptr = reinterpret_cast(p); + return __half2float(*h_ptr); +} + +__device__ __forceinline__ void StoreFloat(Eigen::half* p, float v) { + __half h = 
__float2half(v); + *reinterpret_cast<__half*>(p) = h; +} + +__device__ __forceinline__ float LoadFloat(const bfloat16* p) { + float res = 0.0f; + uint16_t* b_ptr = (uint16_t*)p; + uint32_t* f_ptr = (uint32_t*)&res; + *f_ptr = (static_cast(*b_ptr)) << 16; + return res; +} + +__device__ __forceinline__ void StoreFloat(bfloat16* p, float v) { + uint32_t* f_ptr = (uint32_t*)&v; + uint16_t b_val = (*f_ptr) >> 16; + *reinterpret_cast(p) = b_val; +} +} // namespace + +template +__global__ void ResourceSparseApplyAdamKernel( + T* __restrict__ var, T* __restrict__ m, T* __restrict__ v, + const T* __restrict__ grad, const IndexT* __restrict__ indices, + const T* __restrict__ lr_ptr, const T* __restrict__ beta1_ptr, + const T* __restrict__ beta2_ptr, const T* __restrict__ epsilon_ptr, + const T* __restrict__ beta1_power_ptr, const T* __restrict__ beta2_power_ptr, + int64_t inner_size, int64_t indices_size) { + const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total_elements = indices_size * inner_size; + if (tid >= total_elements) return; + + const int64_t inner_idx = tid % inner_size; + const int64_t indices_idx = tid / inner_size; + const IndexT idx = indices[indices_idx]; + if (idx < 0) return; + + const int64_t var_offset = (int64_t)idx * inner_size + inner_idx; + const int64_t grad_offset = tid; + + // Load gradient and current state values + float g = LoadFloat(&grad[grad_offset]); + float m_val = LoadFloat(&m[var_offset]); + float v_val = LoadFloat(&v[var_offset]); + float var_val = LoadFloat(&var[var_offset]); + + // Load hyperparameters + float lr = LoadFloat(lr_ptr); + float beta1 = LoadFloat(beta1_ptr); + float beta2 = LoadFloat(beta2_ptr); + float epsilon = LoadFloat(epsilon_ptr); + float beta1_power = LoadFloat(beta1_power_ptr); + float beta2_power = LoadFloat(beta2_power_ptr); + + // Compute bias-corrected learning rate + // lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t) + // Handle edge case when beta1_power ≈ 1.0 (initial iteration) + 
float one_minus_beta1_power = 1.0f - beta1_power; + float one_minus_beta2_power = 1.0f - beta2_power; + float lr_t; + if (fabsf(one_minus_beta1_power) < 1e-10f) { + lr_t = lr; // Initial iteration fallback + } else { + lr_t = lr * sqrtf(one_minus_beta2_power) / one_minus_beta1_power; + } + + // Update m: m_t = beta1 * m + (1 - beta1) * g + float one_minus_beta1 = 1.0f - beta1; + float m_new = beta1 * m_val + one_minus_beta1 * g; + + // Update v: v_t = beta2 * v + (1 - beta2) * g^2 + float one_minus_beta2 = 1.0f - beta2; + float v_new = beta2 * v_val + one_minus_beta2 * g * g; + + // Update var: var = var - lr_t * m_new / (sqrt(v_new) + epsilon) + float v_sqrt = sqrtf(v_new); + float var_new = var_val - lr_t * m_new / (v_sqrt + epsilon); + + // Store results + StoreFloat(&m[var_offset], m_new); + StoreFloat(&v[var_offset], v_new); + StoreFloat(&var[var_offset], var_new); +} + +#define OPTIMAL_THREADS 256 +#define OPTIMAL_BLOCKS(n) (((n) + OPTIMAL_THREADS - 1) / OPTIMAL_THREADS) + +template +void LaunchResourceSparseApplyAdamImpl( + T* var, T* m, T* v, const T* grad, const IndexT* indices, + const T* lr, const T* beta1, const T* beta2, const T* epsilon, + const T* beta1_power, const T* beta2_power, + int64_t inner_size, int64_t indices_size, musaStream_t stream) { + int64_t total = inner_size * indices_size; + if (total == 0) return; + ResourceSparseApplyAdamKernel + <<>>( + var, m, v, grad, indices, lr, beta1, beta2, epsilon, + beta1_power, beta2_power, inner_size, indices_size); +} + +#define REGISTER_SPARSE_ADAM_LAUNCHER(T, IndexT) \ + template void LaunchResourceSparseApplyAdamImpl( \ + T* var, T* m, T* v, const T* grad, const IndexT* indices, \ + const T* lr, const T* beta1, const T* beta2, const T* epsilon, \ + const T* beta1_power, const T* beta2_power, \ + int64_t inner_size, int64_t indices_size, musaStream_t stream); + +REGISTER_SPARSE_ADAM_LAUNCHER(float, int32); +REGISTER_SPARSE_ADAM_LAUNCHER(float, int64); +REGISTER_SPARSE_ADAM_LAUNCHER(Eigen::half, 
int32); +REGISTER_SPARSE_ADAM_LAUNCHER(Eigen::half, int64); +REGISTER_SPARSE_ADAM_LAUNCHER(bfloat16, int32); +REGISTER_SPARSE_ADAM_LAUNCHER(bfloat16, int64); + +#undef REGISTER_SPARSE_ADAM_LAUNCHER + +} // namespace musa +} // namespace tensorflow diff --git a/musa_ext/kernels/training/musa_apply_sparse_adam_op.cc b/musa_ext/kernels/training/musa_apply_sparse_adam_op.cc new file mode 100644 index 00000000..6bf017c7 --- /dev/null +++ b/musa_ext/kernels/training/musa_apply_sparse_adam_op.cc @@ -0,0 +1,232 @@ +#include "../utils_op.h" +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_var.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace musa { + +extern Status PrepareTensorForMusaUpdate(OpKernelContext* ctx, Var* var); + +// ============================================================================ +// Define new TensorFlow op for sparse Adam +// ============================================================================ +REGISTER_OP("MusaResourceSparseApplyAdam") + .Input("var: resource") + .Input("m: resource") + .Input("v: resource") + .Input("beta1_power: T") + .Input("beta2_power: T") + .Input("lr: T") + .Input("beta1: T") + .Input("beta2: T") + .Input("epsilon: T") + .Input("grad: T") + .Input("indices: Tindices") + .Attr("T: {float, half, bfloat16}") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn([](shape_inference::InferenceContext* c) { + return Status::OK(); + }); + +// Custom RAII unlocker to avoid issues with TF's mutex_lock macro +class MutexUnlocker { + public: + explicit 
MutexUnlocker(mutex* mu) : mu_(mu) {} + MutexUnlocker(MutexUnlocker&& other) noexcept : mu_(other.mu_) { + other.mu_ = nullptr; + } + MutexUnlocker(const MutexUnlocker&) = delete; + MutexUnlocker& operator=(const MutexUnlocker&) = delete; + + ~MutexUnlocker() { + if (mu_ != nullptr) { + mu_->unlock(); + } + } + + private: + mutex* mu_; +}; + +template +extern void LaunchResourceSparseApplyAdamImpl( + T* var, T* m, T* v, const T* grad, const IndexT* indices, const T* lr, + const T* beta1, const T* beta2, const T* epsilon, const T* beta1_power, + const T* beta2_power, int64_t inner_size, int64_t indices_size, + musaStream_t stream); + +template +class MusaResourceSparseApplyAdamOp : public MusaOpKernel { + public: + explicit MusaResourceSparseApplyAdamOp(OpKernelConstruction* ctx) + : MusaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_locking_)); + } + + void Compute(OpKernelContext* ctx) override { + LOG(INFO) << "[debug for timo] calling MusaResourceSparseApplyAdamOp"; + // Lookup resource variables + core::RefCountPtr var; + core::RefCountPtr m_var; + core::RefCountPtr v_var; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var)); + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &m_var)); + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &v_var)); + + // Lock all variables (following Ftrl pattern) + std::vector vars = {var.get(), m_var.get(), v_var.get()}; + std::vector mutexes; + for (auto* v : vars) { + mutex* mu = v->mu(); + if (std::find(mutexes.begin(), mutexes.end(), mu) == mutexes.end()) { + mutexes.push_back(mu); + } + } + std::sort(mutexes.begin(), mutexes.end()); + + std::vector locks; + locks.reserve(mutexes.size()); + for (mutex* mu : mutexes) { + mu->lock(); + locks.emplace_back(mu); + } + + // Validate initialization + OP_REQUIRES(ctx, + var->tensor()->IsInitialized() && + m_var->tensor()->IsInitialized() && + v_var->tensor()->IsInitialized(), + errors::FailedPrecondition( 
+ "Sparse Adam variables (var/m/v) not initialized.")); + + // Validate shapes match + Tensor* var_tensor = var->tensor(); + Tensor* m_tensor = m_var->tensor(); + Tensor* v_tensor = v_var->tensor(); + + OP_REQUIRES( + ctx, var_tensor->shape().IsSameSize(m_tensor->shape()), + errors::InvalidArgument("var and m must have the same shape. var: ", + var_tensor->shape().DebugString(), + " m: ", m_tensor->shape().DebugString())); + OP_REQUIRES( + ctx, var_tensor->shape().IsSameSize(v_tensor->shape()), + errors::InvalidArgument("var and v must have the same shape. var: ", + var_tensor->shape().DebugString(), + " v: ", v_tensor->shape().DebugString())); + + // Prepare tensors for update (handle copy-on-write) + OP_REQUIRES_OK(ctx, PrepareTensorForMusaUpdate(ctx, var.get())); + OP_REQUIRES_OK(ctx, PrepareTensorForMusaUpdate(ctx, m_var.get())); + OP_REQUIRES_OK(ctx, PrepareTensorForMusaUpdate(ctx, v_var.get())); + + // Refresh tensor pointers after potential copy + var_tensor = var->tensor(); + m_tensor = m_var->tensor(); + v_tensor = v_var->tensor(); + + // Get hyperparameters (host memory scalars) + const Tensor& beta1_power = ctx->input(3); + const Tensor& beta2_power = ctx->input(4); + const Tensor& lr = ctx->input(5); + const Tensor& beta1 = ctx->input(6); + const Tensor& beta2 = ctx->input(7); + const Tensor& epsilon = ctx->input(8); + const Tensor& grad = ctx->input(9); + const Tensor& indices = ctx->input(10); + + // Validate hyperparameter shapes + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()), + errors::InvalidArgument("beta1_power must be scalar: ", + beta1_power.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()), + errors::InvalidArgument("beta2_power must be scalar: ", + beta2_power.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), + errors::InvalidArgument("lr must be scalar: ", + lr.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), 
+ errors::InvalidArgument("beta1 must be scalar: ", + beta1.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), + errors::InvalidArgument("beta2 must be scalar: ", + beta2.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), + errors::InvalidArgument("epsilon must be scalar: ", + epsilon.shape().DebugString())); + + // Validate indices and grad shapes + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument("indices must be a vector: ", + indices.shape().DebugString())); + OP_REQUIRES(ctx, grad.dims() > 0, + errors::InvalidArgument("grad must be at least 1D: ", + grad.shape().DebugString())); + OP_REQUIRES(ctx, grad.dim_size(0) == indices.dim_size(0), + errors::InvalidArgument( + "grad and indices dimension 0 must match. grad: ", + grad.shape().DebugString(), + ", indices: ", indices.shape().DebugString())); + + // Validate var is at least 2D for sparse update + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var_tensor->shape()), + errors::InvalidArgument("var must be at least 1D: ", + var_tensor->shape().DebugString())); + + // Compute sizes + const int64_t inner_size = + var_tensor->shape().num_elements() / var_tensor->dim_size(0); + const int64_t indices_size = indices.dim_size(0); + + musaStream_t stream = GetMusaStreamByCtx(ctx); + + LaunchResourceSparseApplyAdamImpl( + var_tensor->flat().data(), m_tensor->flat().data(), + v_tensor->flat().data(), grad.flat().data(), + indices.flat().data(), lr.flat().data(), + beta1.flat().data(), beta2.flat().data(), + epsilon.flat().data(), beta1_power.flat().data(), + beta2_power.flat().data(), inner_size, indices_size, stream); + + musaError_t sync_err = musaStreamSynchronize(stream); + OP_REQUIRES( + ctx, sync_err == musaSuccess, + errors::Internal( + "MusaResourceSparseApplyAdam: musaStreamSynchronize failed: ", + musaGetErrorString(sync_err))); + } + + private: + bool use_locking_; +}; + +// Register kernels for 
all dtype combinations +#define REGISTER_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("MusaResourceSparseApplyAdam") \ + .Device(DEVICE_MTGPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + MusaResourceSparseApplyAdamOp); \ + REGISTER_KERNEL_BUILDER(Name("MusaResourceSparseApplyAdam") \ + .Device(DEVICE_MTGPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + MusaResourceSparseApplyAdamOp); + +REGISTER_KERNELS(float); +REGISTER_KERNELS(Eigen::half); +REGISTER_KERNELS(bfloat16); + +#undef REGISTER_KERNELS + +} // namespace musa +} // namespace tensorflow diff --git a/python/__init__.py b/python/__init__.py new file mode 100644 index 00000000..7a0b9c5b --- /dev/null +++ b/python/__init__.py @@ -0,0 +1,87 @@ +# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +TensorFlow MUSA Extension - High-performance TensorFlow plugin for Moore Threads GPUs. + +This package provides: +- Automatic plugin loading on import +- Optimized optimizer implementations (Adam, etc.) 
"""
TensorFlow MUSA Extension - High-performance TensorFlow plugin for Moore Threads GPUs.

Importing this package:
  1. Loads the MUSA plugin shared library into TensorFlow (best effort).
  2. Exposes the ``optimizer`` module with MUSA-accelerated optimizers.
  3. Patches ``tf.keras.optimizers.Adam`` to use fused MUSA kernels when at
     least one MUSA device is visible.

Example usage:
    import tensorflow_musa as tf_musa

    # Plugin is loaded automatically on import; tf.keras.optimizers.Adam is
    # patched transparently when MUSA devices are present.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    # Or use the explicit MUSA optimizer:
    optimizer = tf_musa.optimizer.Adam(learning_rate=0.001)

    # Check available MUSA devices:
    devices = tf_musa.get_musa_devices()
"""

import logging

from ._loader import load_plugin, is_plugin_loaded, get_musa_devices, get_musa_ops_module

# Package version.
__version__ = "0.1.0"

# Module-level logger (consistent with _loader.py) instead of the root logger.
logger = logging.getLogger(__name__)

# Plugin state, populated once at import time.
_plugin_loaded = False
_plugin_path = None

try:
    _plugin_path = load_plugin()
    _plugin_loaded = True
except Exception as exc:  # best effort: the package must stay importable
    logger.warning("Failed to load MUSA plugin: %s", exc)
    logger.warning(
        "MUSA functionality will not be available. "
        "Please ensure the plugin is built and MUSA SDK is installed."
    )


# Imported after the plugin load so optimizer code can rely on registered ops.
from . import optimizer

# Patch management utilities.
from ._patch import patch_keras_adam, unpatch_keras_adam, is_adam_patched

# Transparently accelerate tf.keras Adam only when a MUSA device exists.
if _plugin_loaded and get_musa_devices():
    try:
        patch_keras_adam()
    except Exception as exc:
        logger.warning("Failed to patch tf.keras.optimizers.Adam: %s", exc)

# Public API.
__all__ = [
    "__version__",
    "load_plugin",
    "is_plugin_loaded",
    "get_musa_devices",
    "get_musa_ops_module",
    "optimizer",
    "patch_keras_adam",
    "unpatch_keras_adam",
    "is_adam_patched",
]
"""MUSA plugin loading utilities."""

import os
import sys
import logging

logger = logging.getLogger(__name__)

# Name of the shared library produced by the CMake build.
PLUGIN_LIBRARY = "libmusa_plugin.so"

# Module object returned by tf.load_op_library; exposes custom op wrappers
# (e.g. musa_resource_sparse_apply_adam). Populated by load_plugin().
_musa_ops_module = None


def _find_plugin_library():
    """Find the MUSA plugin shared library.

    Search order:
      1. Package installation directory (next to __init__.py)
      2. Project build directory (for development)
      3. System paths (LD_LIBRARY_PATH, /usr/local/musa/lib{,64})

    Returns:
        str: Normalized path of the first existing candidate.

    Raises:
        FileNotFoundError: If the library is not found anywhere.
    """
    package_dir = os.path.dirname(os.path.abspath(__file__))

    candidate_paths = [
        # Package installation directory.
        os.path.join(package_dir, PLUGIN_LIBRARY),
        # Build directory relative to the package (development mode).
        os.path.join(package_dir, "..", "build", PLUGIN_LIBRARY),
        # Build directory relative to the current working directory.
        os.path.join(os.getcwd(), "build", PLUGIN_LIBRARY),
        # System MUSA library paths.
        os.path.join("/usr/local/musa", "lib", PLUGIN_LIBRARY),
        os.path.join("/usr/local/musa", "lib64", PLUGIN_LIBRARY),
    ]

    # LD_LIBRARY_PATH entries are searched last.
    for path in os.environ.get("LD_LIBRARY_PATH", "").split(os.pathsep):
        if path:
            candidate_paths.append(os.path.join(path, PLUGIN_LIBRARY))

    for path in candidate_paths:
        normalized_path = os.path.normpath(path)
        if os.path.exists(normalized_path):
            return normalized_path

    searched = "\n".join(f"  - {os.path.normpath(p)}" for p in candidate_paths)
    raise FileNotFoundError(
        f"MUSA plugin library '{PLUGIN_LIBRARY}' not found.\n"
        f"Searched locations:\n{searched}\n"
        f"Please ensure the plugin has been built (run './build.sh' or 'pip install')."
    )


def load_plugin():
    """Load the MUSA plugin library into TensorFlow.

    Must be called before using any MUSA-specific operations; the plugin
    registers the MUSA device and kernels with TensorFlow and exposes the
    custom ops via get_musa_ops_module().

    Returns:
        str: Path to the loaded plugin library.

    Raises:
        FileNotFoundError: If the plugin library cannot be found.
        RuntimeError: If TensorFlow cannot load the plugin.
    """
    global _musa_ops_module
    import tensorflow as tf

    plugin_path = _find_plugin_library()

    try:
        # tf.load_op_library both loads the shared object and returns the
        # Python wrappers for ops it registers (e.g. MusaResourceSparseApplyAdam).
        _musa_ops_module = tf.load_op_library(plugin_path)
    except Exception as exc:
        # Chain the original error so the TF loader failure is not lost.
        raise RuntimeError(
            f"Failed to load MUSA plugin from {plugin_path}: {exc}\n"
            f"Please ensure TensorFlow and MUSA SDK are properly installed."
        ) from exc

    logger.info("MUSA plugin loaded successfully from: %s", plugin_path)
    return plugin_path


def get_musa_ops_module():
    """Get the module containing custom MUSA ops.

    After load_plugin() succeeds, the returned module exposes ops registered
    via REGISTER_OP, such as musa_resource_sparse_apply_adam.

    Returns:
        module: The MUSA ops module, or None if the plugin is not loaded.
    """
    return _musa_ops_module


def _is_musa_device(device):
    """Best-effort check whether a tf PhysicalDevice belongs to MUSA.

    tf.config.list_physical_devices() returns PhysicalDevice named tuples.
    The previous membership test ("MUSA" in device) only matched a field
    equal to exactly "MUSA"; this substring check also matches names such
    as "/physical_device:MUSA:0".
    """
    name = str(getattr(device, "name", ""))
    device_type = str(getattr(device, "device_type", ""))
    return "MUSA" in name or "MUSA" in device_type


def is_plugin_loaded():
    """Check if the MUSA plugin has been loaded.

    Returns:
        bool: True if at least one MUSA device is visible to TensorFlow.
    """
    import tensorflow as tf

    try:
        return any(_is_musa_device(d) for d in tf.config.list_physical_devices())
    except Exception:
        # Device enumeration can fail before TF is fully initialized;
        # report "not loaded" instead of propagating.
        return False


def get_musa_devices():
    """Get the list of available MUSA devices.

    Returns:
        list: MUSA PhysicalDevice objects (empty if none or on failure).
    """
    import tensorflow as tf

    try:
        return [d for d in tf.config.list_physical_devices() if _is_musa_device(d)]
    except Exception:
        return []
"""Monkey patch utilities for TensorFlow's Adam optimizer.

Patches tf.keras.optimizers.Adam so that _resource_apply_dense /
_resource_apply_sparse dispatch to fused MUSA kernels (ResourceApplyAdam
and the custom MusaResourceSparseApplyAdam) for both dense and sparse
gradient updates.

TensorFlow is imported lazily inside each function (consistent with
_loader.py) so this module can be imported without TF installed.
"""

import logging

logger = logging.getLogger(__name__)

# Maps optimizer class name -> {method name: original unbound function}.
# Filled by patch_keras_adam() and consumed by unpatch_keras_adam().
_original_methods = {}


def _musa_resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Sparse Adam update via the fused MusaResourceSparseApplyAdam kernel.

    Replaces the default multi-op sparse path (assign, scatter_add, ...)
    with a single fused kernel. Falls back to the saved original method
    when the custom op is unavailable.

    Args:
        self: The Adam optimizer instance (patched-in method).
        grad: Sparse gradient values, first dim matching `indices`.
        var: Resource variable to update.
        indices: Row indices for the sparse update.
        apply_state: Optional dict of precomputed hyperparameter values.

    Returns:
        An operation that updates the variable.
    """
    import tensorflow as tf

    from ._loader import get_musa_ops_module

    musa_ops = get_musa_ops_module()
    if musa_ops is None or not hasattr(musa_ops, 'musa_resource_sparse_apply_adam'):
        # Custom op unavailable: delegate to the original implementation.
        return _original_methods['Adam']['_resource_apply_sparse'](
            self, grad, var, indices, apply_state)

    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    # Same hyperparameter derivation as the dense path.
    local_step = tf.cast(self.iterations + 1, var_dtype)
    beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = tf.pow(beta_1_t, local_step)
    beta_2_power = tf.pow(beta_2_t, local_step)

    # NOTE(review): lr is bias-corrected here even though the beta powers are
    # also handed to the kernel — confirm the kernel does not correct again
    # (optimizer/adam.py passes the raw lr; the two paths should agree).
    lr = coefficients['lr_t'] * (tf.sqrt(1 - beta_2_power) / (1 - beta_1_power))
    epsilon = tf.convert_to_tensor(self.epsilon or 1e-7, var_dtype)

    return musa_ops.musa_resource_sparse_apply_adam(
        var=var.handle,
        m=m.handle,
        v=v.handle,
        beta1_power=beta_1_power,
        beta2_power=beta_2_power,
        lr=lr,
        beta1=beta_1_t,
        beta2=beta_2_t,
        epsilon=epsilon,
        grad=grad,
        indices=indices,
        use_locking=self._use_locking)


def _musa_resource_apply_dense(self, grad, var, apply_state=None):
    """Dense Adam update via the fused ResourceApplyAdam kernel.

    Args:
        self: The Adam optimizer instance (patched-in method).
        grad: Dense gradient tensor.
        var: Resource variable to update.
        apply_state: Optional dict of precomputed hyperparameter values.

    Returns:
        An operation that updates the variable.
    """
    import tensorflow as tf

    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    local_step = tf.cast(self.iterations + 1, var_dtype)
    beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = tf.pow(beta_1_t, local_step)
    beta_2_power = tf.pow(beta_2_t, local_step)

    lr = coefficients['lr_t'] * (tf.sqrt(1 - beta_2_power) / (1 - beta_1_power))
    epsilon = tf.convert_to_tensor(self.epsilon or 1e-7, var_dtype)

    # ResourceApplyAdam dispatches to the MUSA kernel when var is on a
    # MUSA device.
    return tf.raw_ops.ResourceApplyAdam(
        var=var.handle,
        m=m.handle,
        v=v.handle,
        beta1_power=beta_1_power,
        beta2_power=beta_2_power,
        lr=lr,
        beta1=beta_1_t,
        beta2=beta_2_t,
        epsilon=epsilon,
        grad=grad,
        use_locking=self._use_locking)


def patch_keras_adam():
    """Patch tf.keras.optimizers.Adam (and NonFusedAdam if present).

    After patching, _resource_apply_dense uses the fused ResourceApplyAdam
    kernel and _resource_apply_sparse uses MusaResourceSparseApplyAdam.
    Idempotent: a repeated call is a no-op. (Previously a second call
    recorded the patched methods as the "originals", so unpatch_keras_adam
    could never restore the true TensorFlow implementation.)
    """
    import tensorflow as tf

    if is_adam_patched():
        logger.info("MUSA Adam optimizer patch already applied; skipping.")
        return

    adam_class = tf.keras.optimizers.Adam

    # Save the genuine originals before replacing them.
    _original_methods['Adam'] = {
        '_resource_apply_dense': adam_class._resource_apply_dense,
        '_resource_apply_sparse': adam_class._resource_apply_sparse,
    }
    adam_class._resource_apply_dense = _musa_resource_apply_dense
    adam_class._resource_apply_sparse = _musa_resource_apply_sparse

    # NonFusedAdam does not exist in every TF version.
    try:
        non_fused_adam_class = tf.keras.optimizers.NonFusedAdam
    except AttributeError:
        non_fused_adam_class = None
    if non_fused_adam_class is not None:
        _original_methods['NonFusedAdam'] = {
            '_resource_apply_dense': non_fused_adam_class._resource_apply_dense,
            '_resource_apply_sparse': non_fused_adam_class._resource_apply_sparse,
        }
        non_fused_adam_class._resource_apply_dense = _musa_resource_apply_dense
        non_fused_adam_class._resource_apply_sparse = _musa_resource_apply_sparse

    logger.info("MUSA Adam optimizer patch applied successfully.")


def unpatch_keras_adam():
    """Restore the original Adam optimizer methods saved by patch_keras_adam()."""
    import tensorflow as tf

    if 'Adam' in _original_methods:
        adam_class = tf.keras.optimizers.Adam
        for method_name, original in _original_methods['Adam'].items():
            setattr(adam_class, method_name, original)

    if 'NonFusedAdam' in _original_methods:
        try:
            non_fused_adam_class = tf.keras.optimizers.NonFusedAdam
        except AttributeError:
            non_fused_adam_class = None
        if non_fused_adam_class is not None:
            for method_name, original in _original_methods['NonFusedAdam'].items():
                setattr(non_fused_adam_class, method_name, original)

    _original_methods.clear()
    logger.info("MUSA Adam optimizer patch removed.")


def is_adam_patched():
    """Check if Adam optimizer has been patched with MUSA kernels.

    Returns:
        bool: True if patched, False otherwise.
    """
    return 'Adam' in _original_methods
"""MUSA-accelerated Adam optimizer using fused kernels."""

import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops


class Adam(optimizer_v2.OptimizerV2):
    """MUSA-accelerated Adam optimizer using fused kernels.

    Implements Adam (Kingma & Ba, 2015) with bias correction:

        m_t  = beta1 * m_{t-1} + (1 - beta1) * g
        v_t  = beta2 * v_{t-1} + (1 - beta2) * g^2
        lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
        var  = var - lr_t * m_t / (sqrt(v_t) + epsilon)

    Dense updates dispatch to the fused ResourceApplyAdam kernel registered
    by the MUSA plugin; sparse updates use the custom
    MusaResourceSparseApplyAdam kernel when available, falling back to a
    densified dense update otherwise.

    Args:
        learning_rate: Float, LearningRateSchedule, or zero-arg callable.
            Defaults to 0.001.
        beta_1: Exponential decay rate for the 1st moment. Defaults to 0.9.
        beta_2: Exponential decay rate for the 2nd moment. Defaults to 0.999.
        epsilon: Small float for numerical stability. Defaults to 1e-7.
        amsgrad: AMSGrad variant; not supported by the fused kernel.
            Defaults to False.
        name: Name for ops created when applying gradients.
            Defaults to "AdamMUSA".
        **kwargs: `clipnorm`, `clipvalue`, `lr` (legacy alias for
            `learning_rate`), `decay`.

    Raises:
        NotImplementedError: If `amsgrad=True` is requested.
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        name="AdamMUSA",
        **kwargs
    ):
        """Initialize Adam optimizer; rejects amsgrad (no fused kernel)."""
        super(Adam, self).__init__(name, **kwargs)

        if amsgrad:
            raise NotImplementedError(
                "AMSGrad variant is not supported in the fused MUSA Adam kernel. "
                "Use tf.keras.optimizers.Adam for AMSGrad support."
            )

        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        # Stored as a plain attribute (Keras convention), not a hyper: the
        # previous version registered a hyper but then read self._epsilon,
        # an attribute _set_hyper never creates (AttributeError at apply time).
        self.epsilon = epsilon or 1e-7

    def _create_slots(self, var_list):
        """Create m (1st moment) and v (2nd moment) slots for each variable."""
        for var in var_list:
            self.add_slot(var, "m")
            self.add_slot(var, "v")

    def _step_coefficients(self, var_dtype):
        """Per-step scalar tensors shared by the dense and sparse paths.

        Beta powers are derived from self.iterations each step instead of
        separately tracked variables; this removes the previous
        _create_hypers/_prepare/_finish bookkeeping, which read
        self._beta_1 / self._beta_2 / self._epsilon attributes that were
        never assigned, and called the nonexistent _prepare_learning_rate().
        """
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        return {
            "lr": self._decayed_lr(var_dtype),
            "beta_1": beta_1_t,
            "beta_2": beta_2_t,
            "beta_1_power": math_ops.pow(beta_1_t, local_step),
            "beta_2_power": math_ops.pow(beta_2_t, local_step),
            "epsilon": math_ops.cast(self.epsilon, var_dtype),
        }

    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply a dense gradient via the fused ResourceApplyAdam kernel.

        Args:
            grad: Dense gradient tensor.
            var: Resource variable to update.
            apply_state: Accepted for interface compatibility; coefficients
                are recomputed from hypers for the current step.

        Returns:
            An operation that updates the variable.
        """
        var_dtype = var.dtype.base_dtype
        coef = self._step_coefficients(var_dtype)
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        # The raw (uncorrected) lr is passed together with the beta powers;
        # ResourceApplyAdam performs the bias correction internally.
        return tf.raw_ops.ResourceApplyAdam(
            var=var.handle,
            m=m.handle,
            v=v.handle,
            beta1_power=coef["beta_1_power"],
            beta2_power=coef["beta_2_power"],
            lr=coef["lr"],
            beta1=coef["beta_1"],
            beta2=coef["beta_2"],
            epsilon=coef["epsilon"],
            grad=grad,
            use_locking=self._use_locking,
            use_nesterov=False,
        )

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply a sparse gradient via MusaResourceSparseApplyAdam.

        Falls back to densifying the gradient and reusing the dense path
        when the custom MUSA op is unavailable.

        Args:
            grad: Gradient values; first dimension matches `indices`.
            var: Resource variable to update.
            indices: Row indices for the sparse update.
            apply_state: Passed through to the dense fallback only.

        Returns:
            An operation that updates the variable.
        """
        from .._loader import get_musa_ops_module

        musa_ops = get_musa_ops_module()
        if musa_ops is None or not hasattr(musa_ops, "musa_resource_sparse_apply_adam"):
            # Densify and reuse the dense path when the custom op is absent.
            dense_grad = tf.convert_to_tensor(
                tf.IndexedSlices(grad, indices, tf.shape(var)))
            return self._resource_apply_dense(dense_grad, var, apply_state)

        var_dtype = var.dtype.base_dtype
        coef = self._step_coefficients(var_dtype)
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        # NOTE(review): the kernel receives the raw lr plus the beta powers,
        # i.e. it is expected to bias-correct internally; _patch.py instead
        # pre-corrects lr — confirm which contract the kernel implements so
        # the two call sites agree.
        return musa_ops.musa_resource_sparse_apply_adam(
            var=var.handle,
            m=m.handle,
            v=v.handle,
            beta1_power=coef["beta_1_power"],
            beta2_power=coef["beta_2_power"],
            lr=coef["lr"],
            beta1=coef["beta_1"],
            beta2=coef["beta_2"],
            epsilon=coef["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking,
        )

    def get_config(self):
        """Return the optimizer configuration for serialization."""
        config = super(Adam, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "beta_1": self._serialize_hyperparameter("beta_1"),
            "beta_2": self._serialize_hyperparameter("beta_2"),
            "epsilon": self.epsilon,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Create an optimizer from a get_config() dict.

        The previous definition was missing @classmethod, so `cls` was bound
        to an instance and deserialization failed.
        """
        return cls(**config)
def check_tensorflow_version():
    """Verify that the required TensorFlow version is installed.

    Returns:
        tuple: (is_installed: bool, version string or None)

    Raises:
        SystemExit: If TensorFlow is installed but its version differs from
            REQUIRED_TF_VERSION.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("WARNING: TensorFlow not installed.")
        print(f"  Required version: {REQUIRED_TF_VERSION}")
        print(f"  Please install: pip install tensorflow=={REQUIRED_TF_VERSION}")
        return False, None

    version = tf.__version__
    if version != REQUIRED_TF_VERSION:
        print("ERROR: TensorFlow version mismatch!")
        print(f"  Required: {REQUIRED_TF_VERSION}")
        print(f"  Installed: {version}")
        print(f"  Please install the correct version: pip install tensorflow=={REQUIRED_TF_VERSION}")
        sys.exit(1)

    print(f"TensorFlow {version} found - OK")
    return True, version


class BuildPluginCommand(Command):
    """Build the MUSA plugin shared library using CMake and copy it into
    the Python package source directory for packaging."""

    user_options = []

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        # Fail fast on a TF version mismatch before spending time in CMake.
        check_tensorflow_version()

        project_root = os.path.abspath(os.path.dirname(__file__))
        build_dir = os.path.join(project_root, BUILD_DIR)

        # exist_ok avoids a race between the existence check and creation.
        os.makedirs(build_dir, exist_ok=True)

        cmake_cmd = [
            "cmake",
            "..",
            "-DCMAKE_BUILD_TYPE=Release",
            "-DMUSA_KERNEL_DEBUG=OFF",
        ]
        print(f"Running CMake configuration: {cmake_cmd}")
        result = subprocess.run(cmake_cmd, cwd=build_dir, check=False)
        if result.returncode != 0:
            print("CMake configuration failed. Please ensure MUSA SDK and TensorFlow are installed.")
            sys.exit(1)

        # os.cpu_count() may return None; fall back to one job rather than
        # emitting the invalid flag "-jNone".
        make_cmd = ["make", f"-j{os.cpu_count() or 1}"]
        print(f"Running make: {make_cmd}")
        result = subprocess.run(make_cmd, cwd=build_dir, check=False)
        if result.returncode != 0:
            print("Make failed.")
            sys.exit(1)

        # Verify the library was actually produced.
        plugin_path = os.path.join(build_dir, PLUGIN_LIBRARY)
        if not os.path.exists(plugin_path):
            print(f"Error: {PLUGIN_LIBRARY} not found after build.")
            sys.exit(1)

        # Copy next to the package sources (source dir is `python/`, the
        # installed package name is `tensorflow_musa`).
        package_lib_path = os.path.join(project_root, SOURCE_DIR, PLUGIN_LIBRARY)
        shutil.copy2(plugin_path, package_lib_path)
        print(f"Successfully built and copied to: {package_lib_path}")


class BdistWheelCommand(bdist_wheel):
    """Custom bdist_wheel: rebuilds the plugin and excludes non-package
    directories (test/, musa_ext/, docs/) from the wheel."""

    def run(self):
        # Check TensorFlow version first.
        check_tensorflow_version()

        # Always rebuild the plugin for wheel packaging so the wheel
        # contains a library matching the current source tree.
        project_root = os.path.abspath(os.path.dirname(__file__))
        BuildPluginCommand(self.distribution).run()

        # Force only tensorflow_musa packages (sources live in `python/`).
        self.distribution.packages = ["tensorflow_musa", "tensorflow_musa.optimizer"]
        self.distribution.package_data = {PACKAGE_NAME: [PLUGIN_LIBRARY]}
        self.distribution.py_modules = None
        self.distribution.package_dir = {"tensorflow_musa": SOURCE_DIR}

        # Prune stale non-package directories from a previous build/lib.
        build_lib = os.path.join(project_root, "build", "lib")
        for stale in ("test", "musa_ext", "docs"):
            stale_dir = os.path.join(build_lib, stale)
            if os.path.exists(stale_dir):
                shutil.rmtree(stale_dir)

        super().run()


def get_long_description():
    """Return README.md contents, falling back to the short description."""
    readme_path = os.path.join(os.path.dirname(__file__), "README.md")
    if os.path.exists(readme_path):
        with open(readme_path, "r", encoding="utf-8") as f:
            return f.read()
    return DESCRIPTION


# Check TensorFlow at setup.py load time (before any build commands) so a
# version mismatch is detected early.
check_tensorflow_version()


setup(
    name=PACKAGE_NAME,
    version=VERSION,
    description=DESCRIPTION,
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    author=AUTHOR,
    license=LICENSE,
    # Map the package name (tensorflow_musa) to the source directory (python/).
    package_dir={"tensorflow_musa": SOURCE_DIR},
    packages=["tensorflow_musa", "tensorflow_musa.optimizer"],
    package_data={
        PACKAGE_NAME: [PLUGIN_LIBRARY],
    },
    python_requires=">=3.7",
    # NOTE: tensorflow is NOT listed in install_requires to prevent pip from
    # downloading it during wheel build. Users must install tensorflow==2.6.1
    # manually before installing tensorflow_musa.
    # See README.md for installation instructions.
    install_requires=[
        "numpy>=1.19.0",
    ],
    cmdclass={
        "bdist_wheel": BdistWheelCommand,
        "build_plugin": BuildPluginCommand,
    },
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    keywords="tensorflow musa gpu moore-threads deep-learning",
)
+# ============================================================================== + +"""Tests for MUSA ResourceSparseApplyAdam operator.""" + +import os +import sys + +# Add the test directory to path for importing musa_test_utils +test_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, test_dir) + +import numpy as np +import tensorflow as tf +from musa_test_utils import MUSATestCase + + +class MusaResourceSparseApplyAdamTest(MUSATestCase): + """Tests for MUSA fused sparse Adam operator.""" + + @classmethod + def setUpClass(cls): + """Load the custom op module via tf.load_op_library.""" + super(MusaResourceSparseApplyAdamTest, cls).setUpClass() + + plugin_path = None + current_dir = os.path.dirname(os.path.abspath(__file__)) + + candidate_paths = [ + os.path.join(current_dir, "..", "..", "build", "libmusa_plugin.so"), + os.path.join(os.path.dirname(current_dir), "..", "build", "libmusa_plugin.so"), + os.path.join(os.getcwd(), "build", "libmusa_plugin.so"), + ] + + for path in candidate_paths: + normalized_path = os.path.normpath(path) + if os.path.exists(normalized_path): + plugin_path = normalized_path + break + + if plugin_path and os.path.exists(plugin_path): + try: + cls._musa_ops = tf.load_op_library(plugin_path) + except Exception as e: + print(f"FAILED: Error loading MUSA ops from {plugin_path}: {e}") + cls._musa_ops = None + else: + searched = [os.path.normpath(p) for p in candidate_paths] + print("MUSA plugin not found. 
Searched:\n" + "\n".join(f" - {loc}" for loc in searched)) + cls._musa_ops = None + + if cls._musa_ops is not None and not hasattr(cls._musa_ops, 'musa_resource_sparse_apply_adam'): + print("MUSA ops module loaded but musa_resource_sparse_apply_adam not found") + print("Available ops:", [op for op in dir(cls._musa_ops) if not op.startswith('_')]) + + def _compute_expected(self, var_np, m_np, v_np, lr_np, beta1_np, beta2_np, + epsilon_np, beta1_power_np, beta2_power_np, + grad_np, indices_np): + """Compute expected Adam update using NumPy.""" + expected_var = var_np.astype(np.float32).copy() + expected_m = m_np.astype(np.float32).copy() + expected_v = v_np.astype(np.float32).copy() + + # Compute bias-corrected learning rate + one_minus_beta1_power = 1.0 - beta1_power_np + one_minus_beta2_power = 1.0 - beta2_power_np + + if abs(one_minus_beta1_power) < 1e-10: + lr_t = lr_np # Initial iteration fallback + else: + lr_t = lr_np * np.sqrt(one_minus_beta2_power) / one_minus_beta1_power + + one_minus_beta1 = 1.0 - beta1_np + one_minus_beta2 = 1.0 - beta2_np + + for i, idx in enumerate(indices_np): + if idx < 0: + continue + + g = grad_np[i].astype(np.float32) + + # m_t = beta1 * m + (1 - beta1) * g + expected_m[idx] = beta1_np * expected_m[idx] + one_minus_beta1 * g + + # v_t = beta2 * v + (1 - beta2) * g^2 + expected_v[idx] = beta2_np * expected_v[idx] + one_minus_beta2 * g * g + + # var = var - lr_t * m_t / (sqrt(v_t) + epsilon) + v_sqrt = np.sqrt(expected_v[idx]) + expected_var[idx] = expected_var[idx] - lr_t * expected_m[idx] / (v_sqrt + epsilon_np) + + return expected_var, expected_m, expected_v + + def _test_logic(self, var_np, m_np, v_np, lr_np, beta1_np, beta2_np, + epsilon_np, beta1_power_np, beta2_power_np, + grad_np, indices_np, dtype, index_dtype): + """Test sparse Adam update on MUSA device.""" + if self._musa_ops is None or not hasattr(self._musa_ops, 'musa_resource_sparse_apply_adam'): + self.skipTest("MUSA sparse Adam op not available") + + expected_var, 
expected_m, expected_v = self._compute_expected( + var_np, m_np, v_np, lr_np, beta1_np, beta2_np, + epsilon_np, beta1_power_np, beta2_power_np, + grad_np, indices_np) + + with tf.device("/device:MUSA:0"): + var = tf.Variable(var_np, dtype=dtype) + m = tf.Variable(m_np, dtype=dtype) + v = tf.Variable(v_np, dtype=dtype) + + # Check if variables are actually placed on MUSA device + # bfloat16 may not be supported on MUSA, causing fallback to CPU + if 'MUSA' not in var.device: + self.skipTest(f"{dtype} variables not supported on MUSA device (placed on {var.device})") + + # For bfloat16/half, use float32 values and cast to ensure device placement + lr = tf.cast(tf.constant(lr_np, dtype=tf.float32), dtype) + beta1 = tf.cast(tf.constant(beta1_np, dtype=tf.float32), dtype) + beta2 = tf.cast(tf.constant(beta2_np, dtype=tf.float32), dtype) + epsilon = tf.cast(tf.constant(epsilon_np, dtype=tf.float32), dtype) + beta1_power = tf.cast(tf.constant(beta1_power_np, dtype=tf.float32), dtype) + beta2_power = tf.cast(tf.constant(beta2_power_np, dtype=tf.float32), dtype) + + grad = tf.constant(grad_np, dtype=dtype) + indices = tf.constant(indices_np, dtype=index_dtype) + + # Call the custom op via the ops module + self._musa_ops.musa_resource_sparse_apply_adam( + var=var.handle, + m=m.handle, + v=v.handle, + beta1_power=beta1_power, + beta2_power=beta2_power, + lr=lr, + beta1=beta1, + beta2=beta2, + epsilon=epsilon, + grad=grad, + indices=indices, + use_locking=False) + + out_var = var.read_value().numpy() + out_m = m.read_value().numpy() + out_v = v.read_value().numpy() + + # Using higher tolerance for half precision + if dtype in [tf.float16, tf.bfloat16]: + self.assertAllClose(expected_var, out_var, atol=1e-2, rtol=1e-2) + self.assertAllClose(expected_m, out_m, atol=1e-2, rtol=1e-2) + self.assertAllClose(expected_v, out_v, atol=1e-2, rtol=1e-2) + else: + self.assertAllClose(expected_var, out_var) + self.assertAllClose(expected_m, out_m) + self.assertAllClose(expected_v, out_v) + + 
def testBasic(self): + """Test basic sparse Adam update.""" + # Note: bfloat16 is tested separately in testBasicBFloat16 as it may not + # be supported on all MUSA devices + for dtype in [tf.float32, tf.float16]: + for index_dtype in [np.int32, np.int64]: + # Simple 2-row, 2-column embedding + var_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + m_np = np.array([[0.0, 0.0], [0.0, 0.0]], dtype=np.float32) + v_np = np.array([[0.0, 0.0], [0.0, 0.0]], dtype=np.float32) + + lr_np = 0.001 + beta1_np = 0.9 + beta2_np = 0.999 + epsilon_np = 1e-7 + beta1_power_np = 0.9 # First iteration + beta2_power_np = 0.999 + + # Update only row 1 + indices_np = np.array([1], dtype=index_dtype) + grad_np = np.array([[0.1, 0.2]], dtype=np.float32) + + self._test_logic( + var_np, m_np, v_np, lr_np, beta1_np, beta2_np, + epsilon_np, beta1_power_np, beta2_power_np, + grad_np, indices_np, dtype, index_dtype) + + def testBasicBFloat16(self): + """Test bfloat16 sparse Adam update (may skip if device doesn't support bfloat16).""" + for index_dtype in [np.int32, np.int64]: + var_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + m_np = np.array([[0.0, 0.0], [0.0, 0.0]], dtype=np.float32) + v_np = np.array([[0.0, 0.0], [0.0, 0.0]], dtype=np.float32) + + lr_np = 0.001 + beta1_np = 0.9 + beta2_np = 0.999 + epsilon_np = 1e-7 + beta1_power_np = 0.9 + beta2_power_np = 0.999 + + indices_np = np.array([1], dtype=index_dtype) + grad_np = np.array([[0.1, 0.2]], dtype=np.float32) + + self._test_logic( + var_np, m_np, v_np, lr_np, beta1_np, beta2_np, + epsilon_np, beta1_power_np, beta2_power_np, + grad_np, indices_np, tf.bfloat16, index_dtype) + + def testMultipleIndices(self): + """Test updating multiple rows.""" + dtype = tf.float32 + + # 5-row, 4-column embedding + rows = 5 + cols = 4 + var_np = np.random.random([rows, cols]).astype(np.float32) + m_np = np.zeros([rows, cols], dtype=np.float32) + v_np = np.zeros([rows, cols], dtype=np.float32) + + lr_np = 0.001 + beta1_np = 0.9 + 
beta2_np = 0.999
+        epsilon_np = 1e-7
+        beta1_power_np = 0.81  # Second iteration (0.9^2)
+        beta2_power_np = 0.998001  # Second iteration (0.999^2)
+
+        # Update rows 0, 2, 4
+        indices_np = np.array([0, 2, 4], dtype=np.int32)
+        grad_np = np.random.random([3, cols]).astype(np.float32)
+
+        self._test_logic(
+            var_np, m_np, v_np, lr_np, beta1_np, beta2_np,
+            epsilon_np, beta1_power_np, beta2_power_np,
+            grad_np, indices_np, dtype, np.int32)
+
+    def testEmbeddingScenario(self):
+        """Test with larger embedding-like dimensions."""
+        for dtype in [tf.float32, tf.float16]:
+            # Simulate word embedding: 1000 words, 64 dimensions
+            vocab_size = 1000
+            embedding_dim = 64
+            var_np = np.random.random([vocab_size, embedding_dim]).astype(np.float32)
+            m_np = np.zeros([vocab_size, embedding_dim], dtype=np.float32)
+            v_np = np.zeros([vocab_size, embedding_dim], dtype=np.float32)
+
+            lr_np = 0.001
+            beta1_np = 0.9
+            beta2_np = 0.999
+            epsilon_np = 1e-7
+            beta1_power_np = 0.9
+            beta2_power_np = 0.999
+
+            # Batch of 32 tokens
+            batch_size = 32
+            indices_np = np.random.choice(vocab_size, batch_size, replace=False).astype(np.int32)
+            grad_np = np.random.random([batch_size, embedding_dim]).astype(np.float32)
+
+            self._test_logic(
+                var_np, m_np, v_np, lr_np, beta1_np, beta2_np,
+                epsilon_np, beta1_power_np, beta2_power_np,
+                grad_np, indices_np, dtype, np.int32)
+
+    def testEmptyIndices(self):
+        """Test with empty indices (no-op)."""
+        dtype = tf.float32
+        var_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+        m_np = np.array([[0.1, 0.1], [0.1, 0.1]], dtype=np.float32)
+        v_np = np.array([[0.1, 0.1], [0.1, 0.1]], dtype=np.float32)
+
+        lr_np = 0.001
+        beta1_np = 0.9
+        beta2_np = 0.999
+        epsilon_np = 1e-7
+        beta1_power_np = 0.9
+        beta2_power_np = 0.999
+
+        indices_np = np.array([], dtype=np.int32)
+        grad_np = np.zeros([0, 2], dtype=np.float32)
+
+        # Should not change anything
+        self._test_logic(
+            var_np, m_np, v_np, lr_np, beta1_np, beta2_np,
+            epsilon_np, beta1_power_np, 
beta2_power_np, + grad_np, indices_np, dtype, np.int32) + + def testLargeRowsInt64Indices(self): + """Test large rows with int64 indices.""" + dtype = tf.float32 + + rows = 500 + cols = 32 + var_np = np.random.random([rows, cols]).astype(np.float32) + m_np = np.zeros([rows, cols], dtype=np.float32) + v_np = np.zeros([rows, cols], dtype=np.float32) + + lr_np = 0.001 + beta1_np = 0.9 + beta2_np = 0.999 + epsilon_np = 1e-7 + beta1_power_np = 0.9 + beta2_power_np = 0.999 + + # Update first and last rows using int64 indices + indices_np = np.array([0, rows - 1], dtype=np.int64) + grad_np = np.random.random([2, cols]).astype(np.float32) + + self._test_logic( + var_np, m_np, v_np, lr_np, beta1_np, beta2_np, + epsilon_np, beta1_power_np, beta2_power_np, + grad_np, indices_np, dtype, np.int64) + + def testInitialIterationEdgeCase(self): + """Test edge case when beta1_power ≈ 1.0 (initial iteration).""" + dtype = tf.float32 + var_np = np.array([[1.0]], dtype=np.float32) + m_np = np.array([[0.0]], dtype=np.float32) + v_np = np.array([[0.0]], dtype=np.float32) + + lr_np = 0.001 + beta1_np = 0.9 + beta2_np = 0.999 + epsilon_np = 1e-7 + beta1_power_np = 1.0 # Initial iteration edge case + beta2_power_np = 1.0 + + indices_np = np.array([0], dtype=np.int32) + grad_np = np.array([[0.5]], dtype=np.float32) + + self._test_logic( + var_np, m_np, v_np, lr_np, beta1_np, beta2_np, + epsilon_np, beta1_power_np, beta2_power_np, + grad_np, indices_np, dtype, np.int32) + + +if __name__ == "__main__": + tf.test.main()