
Commit fc77536

llama : switch to floating-point token positions
ggml-ci
1 parent 15499eb commit fc77536

14 files changed: +68 additions, −61 deletions

examples/baby-llama/baby-llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -1015,9 +1015,9 @@ static struct ggml_tensor * forward_lora(
     struct ggml_tensor * kc = kv_self.k;
     struct ggml_tensor * vc = kv_self.v;

-    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, N);
     {
-        int * data = (int *) KQ_pos->data;
+        float * data = (float *) KQ_pos->data;
         for (int i = 0; i < N; ++i) {
             data[i] = n_past + i;
         }
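The same pattern repeats in the finetune and train-text-from-scratch examples below: the position tensor is allocated as GGML_TYPE_F32 and filled through a float pointer instead of an int pointer. A minimal host-side sketch of the new convention (the helper name is hypothetical and not part of the commit; the context must have real, non-no-alloc data buffers):

    #include "ggml.h"

    // Sketch only: builds a float-typed position vector the way the updated examples do.
    static struct ggml_tensor * make_pos_f32(struct ggml_context * ctx, int n_past, int N) {
        struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);

        float * data = (float *) KQ_pos->data;
        for (int i = 0; i < N; ++i) {
            // integer positions still work unchanged; fractional values
            // are now representable as well
            data[i] = (float)(n_past + i);
        }
        return KQ_pos;
    }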

examples/finetune/finetune.cpp

Lines changed: 2 additions & 2 deletions
@@ -554,7 +554,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     };

     // KQ_pos - contains the positions
-    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);
     ggml_set_input(KQ_pos);

     // rope has so much parameters that we make a custom function for it
@@ -743,7 +743,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(

     // set KQ_pos
     {
-        int * data = (int *) KQ_pos->data;
+        float * data = (float *) KQ_pos->data;
         for (int i = 0; i < N; ++i) {
             data[i] = n_past + i;
         }

examples/llava/llava.cpp

Lines changed: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, (float) *n_past, 1, 0, };
         if (llama_decode(ctx_llama, batch)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return false;
examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
@@ -1281,7 +1281,7 @@ struct llama_server_context
                 }

                 const int n_embd = llama_n_embd(model);
-                llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
+                llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, (float) slot.n_past, 1, 0, };
                 if (llama_decode(ctx, batch_img))
                 {
                     LOG_TEE("%s : failed to eval image\n", __func__);
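In both llama_batch initializers above, the position that starts the batch is now a float, so the previously implicit int32 value needs an explicit (float) cast to keep the aggregate initialization clean. A hedged sketch of the same pattern with the fields spelled out in comments (the member names follow the llama_batch layout assumed for this revision and are for orientation only):

    // Sketch: n_eval and embd_ptr are placeholders for values the caller already has.
    int32_t n_past = 42;
    llama_batch batch = {
        /*n_tokens  */ n_eval,
        /*token     */ nullptr,
        /*embd      */ embd_ptr,
        /*pos       */ nullptr,
        /*n_seq_id  */ nullptr,
        /*seq_id    */ nullptr,
        /*logits    */ nullptr,
        /*all_pos_0 */ (float) n_past, // was an integer position before this change
        /*all_pos_1 */ 1,
        /*all_seq_id*/ 0,
    };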

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 2 additions & 2 deletions
@@ -291,7 +291,7 @@ static struct ggml_tensor * llama_build_train_graphs(
     };

     // KQ_pos - contains the positions
-    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);
     ggml_set_input(KQ_pos);

     // rope has so much parameters that we make a custom function for it
@@ -419,7 +419,7 @@ static struct ggml_tensor * llama_build_train_graphs(
     ggml_gallocr_alloc_graph(alloc, gb);

     if (!measure_only) {
-        int * data = (int *) KQ_pos->data;
+        float * data = (float *) KQ_pos->data;
         for (int i = 0; i < N; ++i) {
             data[i] = n_past + i;
         }

ggml-cuda.cu

Lines changed: 15 additions & 15 deletions
@@ -6040,7 +6040,7 @@ static __device__ void rope_yarn(
 // rope == RoPE == rotary positional embedding
 template<typename T, bool has_pos>
 static __global__ void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    const T * x, T * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
     float ext_factor, float attn_factor, rope_corr_dims corr_dims
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -6053,7 +6053,7 @@ static __global__ void rope(
     const int i = row*ncols + col;
     const int i2 = row/p_delta_rows;

-    const int p = has_pos ? pos[i2] : 0;
+    const float p = has_pos ? pos[i2] : 0.0f;
     const float theta_base = p*powf(freq_base, -float(col)/ncols);

     float cos_theta, sin_theta;
@@ -6068,7 +6068,7 @@ static __global__ void rope(

 template<typename T, bool has_pos>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ncols, int n_dims, const float * pos, float freq_scale, int p_delta_rows,
     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -6095,7 +6095,7 @@ static __global__ void rope_neox(

     float cur_rot = inv_ndims * ic - ib;

-    const int p = has_pos ? pos[i2] : 0;
+    const float p = has_pos ? pos[i2] : 0.0f;
     const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);

     float cos_theta, sin_theta;
@@ -6109,7 +6109,7 @@ static __global__ void rope_neox(
 }

 static __global__ void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    const float * x, float * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
     int n_ctx
 ) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
@@ -6124,10 +6124,10 @@ static __global__ void rope_glm_f32(
     const int i2 = row/p_delta_rows;

     const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;

-    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
+    const float p = pos != nullptr ? pos[i2] : 0.0f;
+
+    const float theta = min(p, (float) n_ctx - 2)*freq_scale*col_theta_scale;
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);

@@ -6137,7 +6137,7 @@ static __global__ void rope_glm_f32(
     dst[i + 0]           = x0*cos_theta - x1*sin_theta;
     dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;

-    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
+    const float block_theta = max(p - n_ctx - 2, 0.0f)*col_theta_scale;
     const float sin_block_theta = sinf(block_theta);
     const float cos_block_theta = cosf(block_theta);

@@ -7688,7 +7688,7 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const

 template<typename T>
 static void rope_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
@@ -7708,7 +7708,7 @@ static void rope_cuda(

 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ncols, int n_dims, int nrows, const float * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
@@ -7733,7 +7733,7 @@ static void rope_neox_cuda(
 }

 static void rope_glm_f32_cuda(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const float * x, float * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
     float freq_base, int n_ctx, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 4 == 0);
@@ -9035,11 +9035,11 @@ static void ggml_cuda_op_rope(
     memcpy(&beta_fast, (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

-    const int32_t * pos = nullptr;
+    const float * pos = nullptr;
     if ((mode & 1) == 0) {
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
         GGML_ASSERT(src1->ne[0] == ne2);
-        pos = (const int32_t *) src1_dd;
+        pos = (const float *) src1_dd;
     }

     const bool is_neox = mode & 2;
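The CUDA kernels now consume the position directly as a float, so the rotation angle is just a product of floats and a fractional position yields a smoothly interpolated angle. A standalone host-side sketch of the non-NeoX angle computation used above (freq_base, the column index, and the dimension count are illustrative values, not taken from the commit):

    #include <cmath>
    #include <cstdio>

    // Mirrors theta_base = p * powf(freq_base, -float(col)/ncols) from the rope kernel.
    int main() {
        const float freq_base = 10000.0f; // typical RoPE base (illustrative)
        const int   ncols     = 128;      // rotated width (illustrative)
        const int   col       = 6;        // column index (illustrative)

        for (float p = 10.0f; p <= 11.0f; p += 0.5f) { // 10.0, 10.5, 11.0
            const float theta_base = p*powf(freq_base, -(float)col/ncols);
            printf("p = %.1f -> theta_base = %f\n", p, theta_base);
        }
        return 0;
    }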

ggml-metal.m

Lines changed: 7 additions & 1 deletion
@@ -2057,7 +2057,13 @@ static bool ggml_metal_graph_compute(
             // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
             const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

-            float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+            float freq_base;
+            float freq_scale;
+            float ext_factor;
+            float attn_factor;
+            float beta_fast;
+            float beta_slow;
+
             memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float));
             memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float));
             memcpy(&ext_factor,  (int32_t *) dst->op_params + 7, sizeof(float));
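As the context lines above show, the RoPE hyperparameters live bit-for-bit inside the op's int32_t op_params array and are recovered with memcpy rather than pointer casts. A small self-contained sketch of that write/read round trip (array size, slot index, and value are illustrative, not from the commit):

    #include <cstring>
    #include <cstdint>
    #include <cstdio>

    int main() {
        int32_t op_params[16] = {0};

        const float freq_base_in = 10000.0f;                  // illustrative value
        memcpy(&op_params[5], &freq_base_in, sizeof(float));  // write side (graph construction)

        float freq_base;
        memcpy(&freq_base, op_params + 5, sizeof(float));     // read side, as in the backends
        printf("freq_base = %f\n", freq_base);
        return 0;
    }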

ggml-metal.metal

Lines changed: 5 additions & 5 deletions
@@ -1674,7 +1674,7 @@ static void rope_yarn_corr_dims(

 typedef void (rope_t)(
         device const    void * src0,
-        device const int32_t * src1,
+        device const   float * src1,
         device         float * dst,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -1709,7 +1709,7 @@ typedef void (rope_t)(
 template<typename T>
 kernel void kernel_rope(
         device const    void * src0,
-        device const int32_t * src1,
+        device const   float * src1,
         device         float * dst,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -1749,11 +1749,11 @@ kernel void kernel_rope(
     float corr_dims[2];
     rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

-    device const int32_t * pos = src1;
+    device const float * pos = src1;

-    const int64_t p = pos[i2];
+    const float p = pos[i2];

-    const float theta_0 = (float)p;
+    const float theta_0 = p;
     const float inv_ndims = -1.f/n_dims;

     if (!is_neox) {

ggml.c

Lines changed: 6 additions & 6 deletions
@@ -5254,7 +5254,7 @@ static struct ggml_tensor * ggml_rope_impl(
         bool   xpos_down,
         bool   inplace) {
     GGML_ASSERT(ggml_is_vector(b));
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);

     bool is_node = false;
@@ -5377,7 +5377,7 @@ struct ggml_tensor * ggml_rope_back(
         float  xpos_base,
         bool   xpos_down) {
     GGML_ASSERT(ggml_is_vector(b));
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);

     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
@@ -12352,11 +12352,11 @@ static void ggml_compute_forward_rope_f32(
     // this essentially just switches the sign of sin.
     const float sin_sign = forward ? 1.0f : -1.0f;

-    const int32_t * pos = (const int32_t *) src1->data;
+    const float * pos = (const float *) src1->data;

     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
+            const float p = pos[i2];

             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
             if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
@@ -12523,11 +12523,11 @@ static void ggml_compute_forward_rope_f16(
     // this essentially just switches the sign of sin.
     const float sin_sign = forward ? 1.0f : -1.0f;

-    const int32_t * pos = (const int32_t *) src1->data;
+    const float * pos = (const float *) src1->data;

     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
+            const float p = pos[i2];

             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
             if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
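With these asserts flipped from GGML_TYPE_I32 to GGML_TYPE_F32, any graph that feeds positions into the RoPE ops must build the position vector as F32; an I32 tensor now trips the assert. A minimal sketch of the caller side, assuming the ggml_rope(ctx, a, b, n_dims, mode, n_ctx) signature of this revision (the helper name and sizes are illustrative):

    #include "ggml.h"

    // Sketch only: shows the new type requirement on the position tensor.
    static struct ggml_tensor * rope_with_f32_positions(struct ggml_context * ctx,
                                                        struct ggml_tensor  * cur,
                                                        int n_past, int n_tokens, int n_rot) {
        // must be F32 now -- GGML_TYPE_I32 would fail GGML_ASSERT(b->type == GGML_TYPE_F32)
        struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_tokens);

        float * p = (float *) pos->data;
        for (int i = 0; i < n_tokens; ++i) {
            p[i] = (float)(n_past + i);
        }

        return ggml_rope(ctx, cur, pos, n_rot, /*mode*/ 0, /*n_ctx*/ 0);
    }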

llama.cpp

Lines changed: 10 additions & 10 deletions
@@ -1699,8 +1699,8 @@ struct llama_layer {
 };

 struct llama_kv_cell {
-    llama_pos pos   = -1;
-    llama_pos delta = 0;
+    float pos   = -1.0f;
+    float delta = 0.0f;

     std::set<llama_seq_id> seq_id;

@@ -1939,10 +1939,10 @@ struct llama_context {
     ggml_context * ctx_input = nullptr;
     struct ggml_tensor * inp_tokens;  // I32 [n_batch]
     struct ggml_tensor * inp_embd;    // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;     // I32 [n_batch]
+    struct ggml_tensor * inp_pos;     // F32 [n_batch]
     struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
     struct ggml_tensor * inp_KQ_pos;  // F32 [n_ctx]
-    struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
+    struct ggml_tensor * inp_K_shift; // F32 [n_ctx]
     struct ggml_tensor * inp_mean;    // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;     // I32 [n_batch]

@@ -2222,7 +2222,7 @@ static void llama_kv_cache_seq_div(
         llama_seq_id seq_id,
         llama_pos p0,
         llama_pos p1,
-        int d) {
+        float d) {
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

@@ -7744,7 +7744,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

         assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));

-        int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+        float * data = (float *) lctx.inp_K_shift->data;

         for (int i = 0; i < n_ctx; ++i) {
             data[i] = lctx.kv_self.cells[i].delta;
@@ -11690,10 +11690,10 @@ struct llama_context * llama_new_context_with_model(

             ctx->inp_tokens  = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
             ctx->inp_embd    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
-            ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
+            ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch);
             ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
             ctx->inp_KQ_pos  = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
-            ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
+            ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
             ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
             ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);

@@ -12046,7 +12046,7 @@ void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, l
     llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
 }

-void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, float d) {
     if (d == 1) {
         return;
     }
@@ -12461,7 +12461,7 @@ int llama_eval_embd(
         int32_t n_past) {
     llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);

-    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
+    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, (float) n_past, 1, 0, };

     const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
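On the public API side, the divisor of llama_kv_cache_seq_div is now a float, so KV-cache positions can be compressed by non-integer factors rather than only by whole numbers. A hedged usage sketch (context setup omitted; the 1.5f factor is purely illustrative):

    #include "llama.h"

    // Sketch: compress the cached positions of sequence 0 in [0, n_past) by 1.5x,
    // something the previous int divisor could not express.
    static void shrink_positions(llama_context * ctx, llama_pos n_past) {
        llama_kv_cache_seq_div(ctx, /*seq_id*/ 0, /*p0*/ 0, /*p1*/ n_past, /*d*/ 1.5f);
    }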
