Why is a prefill needed for newly generated text tokens? #293

Description

@Xer-GWX

Hi, thanks for your great work!

I have a question related to the code below:

```python
if think:
    gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
    gen_context = self.update_context_text(gen_text, gen_context)
    output_list.append(gen_text)
```

gen_text (the decode step) already builds up a kv_cache while generating, yet update_context_text (a prefill) is still run afterwards to compute a kv_cache for the same tokens. If I want to accelerate this, can the update_context_text step be skipped? Why is it designed this way?
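
To make the question concrete, here is a minimal sketch of what I mean by skipping the re-prefill. Everything below is hypothetical: the `return_kv` flag and the `past_key_values` field are my assumptions about a possible interface, not the repo's actual API; only `gen_text` and `update_context_text` are the names from the code above.

```python
# Hypothetical sketch -- NOT the repo's actual API. It only illustrates
# the redundancy I am asking about: decode already extends the KV cache
# one token at a time, so the prefill over gen_text recomputes entries
# that (up to bf16 rounding) already exist.

def step_with_reuse(model, gen_context, **gen_kwargs):
    # Decode: autoregressive generation; assume it can hand back the
    # KV cache it accumulated (return_kv is a made-up flag).
    gen_text, decode_kv = model.gen_text(gen_context, return_kv=True, **gen_kwargs)

    # Instead of re-running update_context_text (a full prefill over
    # gen_text), splice the decode-time cache into the context directly.
    gen_context["past_key_values"] = decode_kv  # assumed context layout
    return gen_text, gen_context
```

If the two caches really only differ by bf16 rounding, reusing the decode-time cache should be numerically equivalent to re-prefilling.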

I also printed the kv_cache tensors from gen_text and from update_context_text. Their means differ by only about 0.0000035, which is quite small and looks more like a precision difference than an accuracy problem.

The print log is below:

```
=== Summary ===

[past_key_cache_after_decode.pt]
  tensor: shape=(3012, 4, 128) dtype=torch.bfloat16 device=cpu mean=-0.0625086 std=5.32648 min=-208 max=55.5

[past_key_values_after_prefill.pt]
  tensor: shape=(3013, 4, 128) dtype=torch.bfloat16 device=cpu mean=-0.0625121 std=5.32567 min=-208 max=55.5
  tensor (trimmed): shape=(3012, 4, 128) mean=-0.0625059 std=5.32648 min=-208 max=55.5
  tensor (last_row): mean=-0.0812546 std=1.58288 min=-7.125 max=7.03125

=== Pairwise comparisons (same key across files) ===

[past_key_cache_after_decode.pt] vs [past_key_values_after_prefill.pt]
  tensor: avg_abs_diff=0.00182607 max_abs_diff=0.785156
```
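
For reference, here is a minimal sketch of the comparison that produced the log. The save format is my assumption: each `.pt` file holds a single key-cache tensor of shape `(seq_len, num_heads, head_dim)` in bfloat16, saved with `torch.save`.

```python
import torch

# Load the two saved key caches (file names as in the log above).
decode = torch.load("past_key_cache_after_decode.pt")     # (3012, 4, 128)
prefill = torch.load("past_key_values_after_prefill.pt")  # (3013, 4, 128)

def stats(name, t):
    t32 = t.float()  # cast bf16 -> fp32 so the statistics are stable
    print(f"[{name}] shape={tuple(t.shape)} "
          f"mean={t32.mean().item():.7g} std={t32.std().item():.6g} "
          f"min={t32.min().item():g} max={t32.max().item():g}")

stats("decode", decode)
stats("prefill", prefill)

# The prefill cache has one extra row (one more token), so trim it to
# the decode length before comparing elementwise.
trimmed = prefill[: decode.shape[0]]
diff = (decode.float() - trimmed.float()).abs()
print(f"avg_abs_diff={diff.mean().item():.6g} "
      f"max_abs_diff={diff.max().item():.6g}")
```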
