
Commit 019a215
address Copilot comments
1 parent 02130b8

7 files changed: +9 −21 lines

jax-inference-offloading/examples/download_model.py

Lines changed: 2 additions & 0 deletions

@@ -150,6 +150,8 @@ def main():
         key = args.kaggle_key or os.getenv("KAGGLE_KEY")
         with stdout_to_stderr():
             return download_kaggle(args.model, args.flavor, username, key)
+    else:
+        raise ValueError(f"Unknown hub: {hub}")
 
 
 if __name__ == "__main__":
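
The new else branch closes an if/elif chain over the hub name so an unrecognized hub fails fast instead of falling off the end and returning None. A minimal sketch of the pattern, with hypothetical hub names and stub helpers (only the final else/raise is taken from the diff):

def download_huggingface(model: str) -> str:  # hypothetical stub
    return f"/models/hf/{model}"

def download_kaggle(model: str) -> str:  # hypothetical stub
    return f"/models/kaggle/{model}"

def download(hub: str, model: str) -> str:
    if hub == "huggingface":
        return download_huggingface(model)
    elif hub == "kaggle":
        return download_kaggle(model)
    else:
        # Fail loudly on unknown values instead of implicitly returning None.
        raise ValueError(f"Unknown hub: {hub}")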

jax-inference-offloading/jax_inference_offloading/transport/model/nccl_grouped.py

Lines changed: 0 additions & 1 deletion

@@ -71,7 +71,6 @@ def __init__(
     def __call__(self, named_parameters: Dict[str, jax.Array]):
         mapping_specs = self._mapping_specs
         gateway = self._gateway
-        transports = self._transports
         transport_config = self._transport_config
 
         gateway.start_weight_transfer('grouped')

jax-inference-offloading/jax_inference_offloading/transport/model/nccl_unfused.py

Lines changed: 0 additions & 1 deletion

@@ -30,7 +30,6 @@
 from jax_inference_offloading.controller.trainer_client import TrainerClient
 from jax_inference_offloading.sharding import PolymorphicMesh
 from jax_inference_offloading.transport.tensor.nccl_star import NcclStarTransport
-from jax_inference_offloading.models.mapping_util import _proto_to_slice
 from jax_inference_offloading.timer import Timer
 
 logger = logging.getLogger(__name__)

jax-inference-offloading/jax_inference_offloading/transport/tensor/nccl_base.py

Lines changed: 0 additions & 7 deletions

@@ -30,13 +30,6 @@
 except ImportError:
     has_torch = False
 
-try:
-    import jax
-    import jax.numpy as jnp
-    has_jax = True
-except ImportError:
-    has_jax = False
-
 
 logger = getLogger(__name__)
 
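
The deleted block mirrored the optional-import guard the file keeps for torch; it was dead code because has_jax was never read. For reference, a minimal sketch of the guard pattern itself (the describe helper is illustrative, not from this repo):

from logging import getLogger

try:
    import torch
    has_torch = True
except ImportError:
    has_torch = False

logger = getLogger(__name__)

def describe(x) -> str:
    # Touch torch APIs only behind the flag, so importing this module
    # never fails on machines without torch installed.
    if has_torch and isinstance(x, torch.Tensor):
        return f"torch tensor, shape={tuple(x.shape)}"
    return type(x).__name__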

jax-inference-offloading/jax_inference_offloading/transport/tensor/nccl_star.py

Lines changed: 4 additions & 7 deletions

@@ -31,12 +31,6 @@
 except ImportError:
     has_torch = False
 
-try:
-    import jax
-    import jax.numpy as jnp
-    has_jax = True
-except ImportError:
-    has_jax = False
 
 logger = getLogger(__name__)
 

@@ -120,6 +114,8 @@ def create_trainer_transport(cls, config: dict[str, Any]) -> List['NcclStarTransport']:
                 comm = nccl.NcclCommunicator(world_size, unique_id, 0)  # trainer rank is at the center of the star in fan-out mode and is always rank 0
                 transports.append(cls(comm=comm))
             return transports
+        else:
+            raise ValueError(f"Unknown transport mode: {config['MODE']}")
 
     @classmethod
     def create_rollout_transport(cls, config: dict[str, Any], tp_rank: int) -> 'NcclStarTransport':

@@ -141,6 +137,8 @@ def create_rollout_transport(cls, config: dict[str, Any], tp_rank: int) -> 'NcclStarTransport':
             unique_id = cls.decode_nccl_id(config["UNIQUE_IDS"][star_id])
             comm = nccl.NcclCommunicator(world_size, unique_id, rank)
             return cls(comm)
+        else:
+            raise ValueError(f"Unknown transport mode: {config['MODE']}")
 
     def __init__(self, comm: nccl.NcclCommunicator):
         self._comm = comm

@@ -272,7 +270,6 @@ def scatter_grouped(self, all_buffers: List[List[Any]]) -> None:
         assert self._comm.rank_id() == 0, \
             "Star scatter must originate from the root (rank 0)."
 
-        num_peers = self._comm.size() - 1
         with cuda.Device(self._comm.device_id()):
             stream = cuda.get_current_stream().ptr
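
As the diff's inline comment notes, the trainer sits at the center of the star and is always rank 0, while each rollout TP shard joins as a peer. A hedged sketch of how the two factory methods' communicator calls pair up; the process layout, helper names, and the peer-rank arithmetic are assumptions beyond what the diff shows, while cupy's nccl.get_unique_id and nccl.NcclCommunicator are real APIs:

from cupy.cuda import nccl

# The unique id is generated once (nccl.get_unique_id()) and shared
# out-of-band, e.g. via the encoded UNIQUE_IDS config entries above.

def init_trainer_comm(unique_id, num_peers: int):
    # Trainer side: center of the star, always rank 0.
    return nccl.NcclCommunicator(num_peers + 1, unique_id, 0)

def init_rollout_comm(unique_id, num_peers: int, tp_rank: int):
    # Rollout side: one communicator per TP shard; illustrative rank scheme.
    return nccl.NcclCommunicator(num_peers + 1, unique_id, tp_rank + 1)

Communicator construction is collective: every rank in a star must call NcclCommunicator concurrently, each in its own process, or initialization blocks.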

jax-inference-offloading/jax_inference_offloading/tunix/load_model.py

Lines changed: 0 additions & 1 deletion

@@ -18,7 +18,6 @@
 from dataclasses import asdict, replace
 
 import jax
-from flax import nnx
 
 from tunix.models.dummy_model_creator import create_dummy_model
 

jax-inference-offloading/jax_inference_offloading/vllm/extension.py

Lines changed: 3 additions & 4 deletions

@@ -164,12 +164,12 @@ def update_weights(self, mapping_specs: TpModelMappingSpecs):
             if sharding_specs.parallelism > 0:
                 shape[sharding_specs.dim] //= sharding_specs.parallelism
 
-            # logger.warning(f'vLLM TP rank {tp_rank} receiving {param.vllm_param.name} ...')
+            logger.debug(f'vLLM TP rank {tp_rank} receiving {param.vllm_param.name} ...')
             weight = self.transport.gather(
                 shape, param.vllm_param.dtype or 'bfloat16',
                 sharding_specs.aux_dim, sharding_specs.aux_parallelism
             )
-            # logger.warning(f'vLLM TP rank {tp_rank} received {param.vllm_param.name} shape {weight.shape}')
+            logger.debug(f'vLLM TP rank {tp_rank} received {param.vllm_param.name} shape {weight.shape}')
             self._staged_weights.append((param.vllm_param.name, weight))
 
             # TODO: make it optional

@@ -183,8 +183,7 @@ def update_weights(self, mapping_specs: TpModelMappingSpecs):
             if sharding_specs.parallelism > 0:
                 shape[sharding_specs.dim] //= sharding_specs.parallelism
 
-            raw_specs_str = ' '.join(str(sharding_specs).split('\n'))
-            logger.info(f"vLLM expecting: {param.vllm_param.name} shape {shape.tolist()} raw specs {param}")
+            logger.debug(f"vLLM expecting: {param.vllm_param.name} shape {shape.tolist()} raw specs {param}")
 
             weight = self.transport.recv(shape, param.vllm_param.dtype or 'bfloat16')
             self._staged_weights.append((param.vllm_param.name, weight))
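
Replacing the commented-out logger.warning calls with live logger.debug calls keeps the messages available without polluting production logs. One caveat worth noting as a side note, not part of the commit: f-strings are formatted eagerly even when DEBUG is disabled, so for hot transfer loops the lazy %-style form is cheaper. A small sketch with an illustrative helper name:

import logging

logger = logging.getLogger(__name__)

def log_receive(tp_rank: int, name: str) -> None:
    # An f-string variant formats the message even when DEBUG is off:
    #   logger.debug(f"vLLM TP rank {tp_rank} receiving {name} ...")
    # %-style arguments defer formatting until a handler actually emits:
    logger.debug("vLLM TP rank %d receiving %s ...", tp_rank, name)

    # For genuinely expensive message construction, guard explicitly:
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("detail: %s", name)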
