csrc/search/rpc.hpp (1 change: 1 addition & 0 deletions)
@@ -1,6 +1,7 @@
 #ifndef RPC_HPP
 #define RPC_HPP
 
+#include <cstdint>
 #include <string>
 #include <vector>
 #include <unordered_map>
realhf/base/cluster.py (8 changes: 6 additions & 2 deletions)
@@ -8,8 +8,12 @@
 
 
 def get_user_tmp():
-    user = getpass.getuser()
-    user_tmp = os.path.join("/home", user, ".cache", "realhf")
+    home_dir = os.environ.get('HOME', '')
+    if not home_dir:
+        user = getpass.getuser()
+        user_tmp = os.path.join("/home", user, ".cache", "realhf")
+    else:
+        user_tmp = os.path.join(home_dir, ".cache", "realhf")
     os.makedirs(user_tmp, exist_ok=True)
     return user_tmp
 
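A minimal sketch of how the updated `get_user_tmp()` behaves after this change: the cache directory now follows `$HOME` when it is set, and only falls back to the old hard-coded `/home/<user>` path when it is empty. The import path is the real module touched by the diff; the `HOME` values and printed paths are hypothetical.

```python
# Sketch only: hypothetical HOME values; get_user_tmp is the real function from this PR.
import os
from realhf.base.cluster import get_user_tmp

os.environ["HOME"] = "/data/users/alice"   # e.g., a non-standard home directory
print(get_user_tmp())                      # expected: /data/users/alice/.cache/realhf

os.environ["HOME"] = ""                    # simulate a missing HOME
print(get_user_tmp())                      # falls back to /home/<username>/.cache/realhf
```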
realhf/impl/model/backend/megatron.py (30 changes: 9 additions & 21 deletions)
@@ -8,27 +8,15 @@
 import torch.distributed as dist
 import transformers
 
-try:
-    from megatron.core import parallel_state
-    from megatron.core.distributed.distributed_data_parallel import (
-        DistributedDataParallel,
-    )
-    from megatron.core.distributed.param_and_grad_buffer import ParamAndGradBuffer
-    from megatron.core.optimizer import DistributedOptimizer, get_megatron_optimizer
-    from megatron.core.optimizer.clip_grads import clip_grad_norm_fp32, count_zeros_fp32
-    from megatron.core.optimizer.optimizer_config import OptimizerConfig
-    from megatron.core.transformer.transformer_config import TransformerConfig
-except (ModuleNotFoundError, ImportError):
-    # importing megatron.core in CPU container will fail due to the requirement of apex
-    # Here class types must be defined for type hinting
-    class TransformerConfig:
-        pass
-
-    class DistributedDataParallel:
-        pass
-
-    class DistributedOptimizer:
-        pass
+from megatron.core import parallel_state
+from megatron.core.distributed.distributed_data_parallel import (
+    DistributedDataParallel,
+)
+from megatron.core.distributed.param_and_grad_buffer import ParamAndGradBuffer
+from megatron.core.optimizer import DistributedOptimizer, get_megatron_optimizer
+from megatron.core.optimizer.clip_grads import clip_grad_norm_fp32, count_zeros_fp32
+from megatron.core.optimizer.optimizer_config import OptimizerConfig
+from megatron.core.transformer.transformer_config import TransformerConfig
 
 
 from realhf.api.core import model_api
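With the fallback stubs removed, importing this backend module now fails outright whenever megatron.core (and its apex dependency) is unavailable, instead of silently substituting dummy classes. A hedged sketch of how a caller could keep a soft failure at the call site if needed: the module path is real, the guard and error message are hypothetical.

```python
# Sketch only: guard the backend import at the call site instead of relying on
# the removed in-module try/except stubs.
try:
    from realhf.impl.model.backend import megatron as megatron_backend
except ImportError:
    megatron_backend = None  # e.g., a CPU-only container without apex/megatron.core

if megatron_backend is None:
    raise RuntimeError(
        "The Megatron backend was requested but megatron.core is not installed."
    )
```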
realhf/impl/model/modules/attn.py (13 changes: 5 additions & 8 deletions)
@@ -16,14 +16,11 @@
 from .mlp import LayerNormQKVLinear
 from .rotary import RotaryEmbedding
 
-try:
-    from flash_attn import (
-        flash_attn_func,
-        flash_attn_varlen_func,
-        flash_attn_with_kvcache,
-    )
-except ModuleNotFoundError:
-    pass
+from flash_attn import (
+    flash_attn_func,
+    flash_attn_varlen_func,
+    flash_attn_with_kvcache,
+)
 
 logger = logging.getLogger("Attention")
 
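Dropping the try/except makes flash-attn a hard requirement for realhf.impl.model.modules.attn. A small, hedged pre-flight check one could run before loading the model modules; importlib is standard library, and the error message is illustrative.

```python
# Sketch of an environment pre-check, assuming flash-attn is now mandatory
# for realhf.impl.model.modules.attn after this change.
import importlib.util

if importlib.util.find_spec("flash_attn") is None:
    raise RuntimeError(
        "flash-attn is required by realhf.impl.model.modules.attn; "
        "install it before importing the attention modules."
    )
```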