diff --git a/map_embeddings.py b/map_embeddings.py index 4882504..509b11a 100644 --- a/map_embeddings.py +++ b/map_embeddings.py @@ -61,6 +61,7 @@ def main(): parser.add_argument('--encoding', default='utf-8', help='the character encoding for input/output (defaults to utf-8)') parser.add_argument('--precision', choices=['fp16', 'fp32', 'fp64'], default='fp32', help='the floating-point precision (defaults to fp32)') parser.add_argument('--cuda', action='store_true', help='use cuda (requires cupy)') + parser.add_argument('--device_id', default=0, type=int, help='device ID used for the --cuda option (defaults to 0)') parser.add_argument('--batch_size', default=10000, type=int, help='batch size (defaults to 10000); does not affect results, larger is usually faster but uses more memory') parser.add_argument('--seed', type=int, default=0, help='the random seed (defaults to 0)') @@ -153,6 +154,7 @@ def main(): print('ERROR: Install CuPy for CUDA support', file=sys.stderr) sys.exit(-1) xp = get_cupy() + xp.cuda.Device(args.device_id).use() x = xp.asarray(x) z = xp.asarray(z) else: