@@ -104,10 +104,29 @@ class llama_kv_cells_unified {
104
104
res.resize (n);
105
105
106
106
for (uint32_t j = 0 ; j < n; ++j) {
107
- res.pos [j] = pos[i + j];
108
- res.seq [j] = seq[i + j];
107
+ const auto idx = i + j;
109
108
110
- assert (shift[i + j] == 0 );
109
+ res.pos [j] = pos[idx];
110
+ res.seq [j] = seq[idx];
111
+
112
+ assert (shift[idx] == 0 );
113
+ }
114
+
115
+ return res;
116
+ }
117
+
118
+ llama_kv_cells_unified cp (const std::vector<uint32_t > & idxs) const {
119
+ llama_kv_cells_unified res;
120
+
121
+ res.resize (idxs.size ());
122
+
123
+ for (uint32_t j = 0 ; j < idxs.size (); ++j) {
124
+ const auto idx = idxs[j];
125
+
126
+ res.pos [j] = pos[idx];
127
+ res.seq [j] = seq[idx];
128
+
129
+ assert (shift[idx] == 0 );
111
130
}
112
131
113
132
return res;
@@ -118,26 +137,57 @@ class llama_kv_cells_unified {
118
137
assert (i + other.pos .size () <= pos.size ());
119
138
120
139
for (uint32_t j = 0 ; j < other.pos .size (); ++j) {
121
- if (pos[i + j] == -1 && other.pos [j] != -1 ) {
140
+ const auto idx = i + j;
141
+
142
+ if (pos[idx] == -1 && other.pos [j] != -1 ) {
122
143
used.insert (i + j);
123
144
}
124
145
125
- if (pos[i + j ] != -1 && other.pos [j] == -1 ) {
146
+ if (pos[idx ] != -1 && other.pos [j] == -1 ) {
126
147
used.erase (i + j);
127
148
}
128
149
129
- if (pos[i + j ] != -1 ) {
150
+ if (pos[idx ] != -1 ) {
130
151
seq_pos_rm (i + j);
131
152
}
132
153
133
- pos[i + j ] = other.pos [j];
134
- seq[i + j ] = other.seq [j];
154
+ pos[idx ] = other.pos [j];
155
+ seq[idx ] = other.seq [j];
135
156
136
- if (pos[i + j ] != -1 ) {
157
+ if (pos[idx ] != -1 ) {
137
158
seq_pos_add (i + j);
138
159
}
139
160
140
- assert (shift[i + j] == 0 );
161
+ assert (shift[idx] == 0 );
162
+ }
163
+ }
164
+
165
+ void set (const std::vector<uint32_t > & idxs, const llama_kv_cells_unified & other) {
166
+ assert (idxs.size () == other.pos .size ());
167
+
168
+ for (uint32_t j = 0 ; j < other.pos .size (); ++j) {
169
+ const auto idx = idxs[j];
170
+
171
+ if (pos[idx] == -1 && other.pos [j] != -1 ) {
172
+ used.insert (idx);
173
+ }
174
+
175
+ if (pos[idx] != -1 && other.pos [j] == -1 ) {
176
+ used.erase (idx);
177
+ }
178
+
179
+ if (pos[idx] != -1 ) {
180
+ seq_pos_rm (idx);
181
+ }
182
+
183
+ pos[idx] = other.pos [j];
184
+ seq[idx] = other.seq [j];
185
+
186
+ if (pos[idx] != -1 ) {
187
+ seq_pos_add (idx);
188
+ }
189
+
190
+ assert (shift[idx] == 0 );
141
191
}
142
192
}
143
193
0 commit comments