@@ -322,49 +322,77 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
  * This macro generates the code for CTR and XCTR mode.
  */
 .macro ctr_encrypt xctr
+	// Arguments
+	OUT		.req x0
+	IN		.req x1
+	KEY		.req x2
+	ROUNDS_W	.req w3
+	BYTES_W		.req w4
+	IV		.req x5
+	BYTE_CTR_W	.req w6		// XCTR only
+	// Intermediate values
+	CTR_W		.req w11	// XCTR only
+	CTR		.req x11	// XCTR only
+	IV_PART		.req x12
+	BLOCKS		.req x13
+	BLOCKS_W	.req w13
+
 	stp		x29, x30, [sp, #-16]!
 	mov		x29, sp
 
-	enc_prepare	w3, x2, x12
-	ld1		{vctr.16b}, [x5]
+	enc_prepare	ROUNDS_W, KEY, IV_PART
+	ld1		{vctr.16b}, [IV]
 
+	/*
+	 * Keep 64 bits of the IV in a register. For CTR mode this lets us
+	 * easily increment the IV. For XCTR mode this lets us efficiently XOR
+	 * the 64-bit counter with the IV.
+	 */
 	.if \xctr
-	umov		x12, vctr.d[0]
-	lsr		w11, w6, #4
+	umov		IV_PART, vctr.d[0]
+	lsr		CTR_W, BYTE_CTR_W, #4
 	.else
-	umov		x12, vctr.d[1] /* keep swabbed ctr in reg */
-	rev		x12, x12
+	umov		IV_PART, vctr.d[1]
+	rev		IV_PART, IV_PART
 	.endif
 
 .LctrloopNx\xctr:
-	add		w7, w4, #15
-	sub		w4, w4, #MAX_STRIDE << 4
-	lsr		w7, w7, #4
+	add		BLOCKS_W, BYTES_W, #15
+	sub		BYTES_W, BYTES_W, #MAX_STRIDE << 4
+	lsr		BLOCKS_W, BLOCKS_W, #4
 	mov		w8, #MAX_STRIDE
-	cmp		w7, w8
-	csel		w7, w7, w8, lt
+	cmp		BLOCKS_W, w8
+	csel		BLOCKS_W, BLOCKS_W, w8, lt
 
+	/*
+	 * Set up the counter values in v0-v{MAX_STRIDE-1}.
+	 *
+	 * If we are encrypting less than MAX_STRIDE blocks, the tail block
+	 * handling code expects the last keystream block to be in
+	 * v{MAX_STRIDE-1}. For example: if encrypting two blocks with
+	 * MAX_STRIDE=5, then v3 and v4 should have the next two counter blocks.
+	 */
 	.if \xctr
-	add		x11, x11, x7
+	add		CTR, CTR, BLOCKS
 	.else
-	adds		x12, x12, x7
+	adds		IV_PART, IV_PART, BLOCKS
 	.endif
 	mov		v0.16b, vctr.16b
 	mov		v1.16b, vctr.16b
 	mov		v2.16b, vctr.16b
 	mov		v3.16b, vctr.16b
 ST5(	mov		v4.16b, vctr.16b	)
 	.if \xctr
-	sub		x6, x11, #MAX_STRIDE - 1
-	sub		x7, x11, #MAX_STRIDE - 2
-	sub		x8, x11, #MAX_STRIDE - 3
-	sub		x9, x11, #MAX_STRIDE - 4
-ST5(	sub		x10, x11, #MAX_STRIDE - 5	)
-	eor		x6, x6, x12
-	eor		x7, x7, x12
-	eor		x8, x8, x12
-	eor		x9, x9, x12
-ST5(	eor		x10, x10, x12	)
+	sub		x6, CTR, #MAX_STRIDE - 1
+	sub		x7, CTR, #MAX_STRIDE - 2
+	sub		x8, CTR, #MAX_STRIDE - 3
+	sub		x9, CTR, #MAX_STRIDE - 4
+ST5(	sub		x10, CTR, #MAX_STRIDE - 5	)
+	eor		x6, x6, IV_PART
+	eor		x7, x7, IV_PART
+	eor		x8, x8, IV_PART
+	eor		x9, x9, IV_PART
+ST5(	eor		x10, x10, IV_PART	)
 	mov		v0.d[0], x6
 	mov		v1.d[0], x7
 	mov		v2.d[0], x8
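The XCTR branch above builds one stride of counter blocks entirely from the IV held in vctr plus the 64-bit values in x6 through x10: each block starts as a copy of the IV, and only its low 64 bits are replaced by (CTR - (MAX_STRIDE - 1 - i)) XOR IV_PART. The following rough C model mirrors that layout; the function name, the fixed MAX_STRIDE, and the assumption that the block counter fits in 64 bits are illustrative and not taken from the patch.

#include <stdint.h>
#include <string.h>

#define MAX_STRIDE 5	/* 5 with ST5 interleaving, 4 otherwise */

/*
 * Rough model of one stride of XCTR counter blocks.  "ctr" is the block
 * counter *after* the "add CTR, CTR, BLOCKS" above, and "iv" is the
 * 16-byte IV held in vctr.  Only the low 64 bits are XORed with the
 * counter (the eor into d[0]); the high 64 bits keep the IV contents.
 */
static void xctr_stride_blocks(uint8_t blk[MAX_STRIDE][16],
			       const uint8_t iv[16], uint64_t ctr)
{
	uint64_t iv_lo;

	memcpy(&iv_lo, iv, sizeof(iv_lo));	/* IV_PART = vctr.d[0] */

	for (unsigned int i = 0; i < MAX_STRIDE; i++) {
		/* x6..x10 = CTR - (MAX_STRIDE - 1 - i), XORed with IV_PART */
		uint64_t lo = (ctr - (MAX_STRIDE - 1 - i)) ^ iv_lo;

		memcpy(blk[i], iv, 16);			/* mov vN.16b, vctr.16b */
		memcpy(blk[i], &lo, sizeof(lo));	/* mov vN.d[0], xN */
	}
}

When fewer than MAX_STRIDE blocks are needed, only the trailing entries of blk[] are consumed, which is why the comment above insists that the last keystream block land in v{MAX_STRIDE-1}.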
@@ -373,17 +401,32 @@ ST5(	mov		v4.d[0], x10		)
 	.else
 	bcs		0f
 	.subsection	1
-	/* apply carry to outgoing counter */
+	/*
+	 * This subsection handles carries.
+	 *
+	 * Conditional branching here is allowed with respect to time
+	 * invariance since the branches are dependent on the IV instead
+	 * of the plaintext or key. This code is rarely executed in
+	 * practice anyway.
+	 */
+
+	/* Apply carry to outgoing counter. */
0:	umov		x8, vctr.d[0]
 	rev		x8, x8
 	add		x8, x8, #1
 	rev		x8, x8
 	ins		vctr.d[0], x8
 
-	/* apply carry to N counter blocks for N := x12 */
-	cbz		x12, 2f
+	/*
+	 * Apply carry to counter blocks if needed.
+	 *
+	 * Since the carry flag was set, we know 0 <= IV_PART <
+	 * MAX_STRIDE. Using the value of IV_PART we can determine how
+	 * many counter blocks need to be updated.
+	 */
+	cbz		IV_PART, 2f
 	adr		x16, 1f
-	sub		x16, x16, x12, lsl #3
+	sub		x16, x16, IV_PART, lsl #3
 	br		x16
 	bti		c
 	mov		v0.d[0], vctr.d[0]
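For plain CTR the counter is a 128-bit big-endian value, but only its byte-swapped low 64 bits live in IV_PART; the subsection entered on carry is what propagates an overflow of those 64 bits into vctr.d[0]. The snippet below is an illustrative C rendering of just that carry propagation and does not model the fix-up of the in-flight v0-v4 blocks; the struct and helper names are hypothetical.

#include <stdint.h>

/* The 128-bit CTR counter as stored in vctr: big endian, split in two. */
struct ctr128_be {
	uint64_t hi_be;		/* vctr.d[0] */
	uint64_t lo_be;		/* vctr.d[1], kept byte-swapped in IV_PART */
};

static void ctr128_add_blocks(struct ctr128_be *ctr, uint64_t blocks)
{
	uint64_t lo = __builtin_bswap64(ctr->lo_be);	/* rev IV_PART, IV_PART */
	uint64_t new_lo = lo + blocks;			/* adds IV_PART, IV_PART, BLOCKS */

	if (new_lo < lo) {				/* carry set: the "bcs 0f" path */
		uint64_t hi = __builtin_bswap64(ctr->hi_be);

		ctr->hi_be = __builtin_bswap64(hi + 1);	/* apply carry to outgoing counter */
	}
	ctr->lo_be = __builtin_bswap64(new_lo);		/* ins vctr.d[1], x7 at label 2: */
}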
@@ -398,71 +441,88 @@ ST5(	mov		v4.d[0], vctr.d[0]	)
1:	b		2f
 	.previous
 
-2:	rev		x7, x12
+2:	rev		x7, IV_PART
 	ins		vctr.d[1], x7
-	sub		x7, x12, #MAX_STRIDE - 1
-	sub		x8, x12, #MAX_STRIDE - 2
-	sub		x9, x12, #MAX_STRIDE - 3
+	sub		x7, IV_PART, #MAX_STRIDE - 1
+	sub		x8, IV_PART, #MAX_STRIDE - 2
+	sub		x9, IV_PART, #MAX_STRIDE - 3
 	rev		x7, x7
 	rev		x8, x8
 	mov		v1.d[1], x7
 	rev		x9, x9
-ST5(	sub		x10, x12, #MAX_STRIDE - 4	)
+ST5(	sub		x10, IV_PART, #MAX_STRIDE - 4	)
 	mov		v2.d[1], x8
 ST5(	rev		x10, x10	)
 	mov		v3.d[1], x9
 ST5(	mov		v4.d[1], x10	)
 	.endif
-	tbnz		w4, #31, .Lctrtail\xctr
-	ld1		{v5.16b-v7.16b}, [x1], #48
+
+	/*
+	 * If there are at least MAX_STRIDE blocks left, XOR the data with
+	 * keystream and store. Otherwise jump to tail handling.
+	 */
+	tbnz		BYTES_W, #31, .Lctrtail\xctr
+	ld1		{v5.16b-v7.16b}, [IN], #48
 ST4(	bl		aes_encrypt_block4x		)
 ST5(	bl		aes_encrypt_block5x		)
 	eor		v0.16b, v5.16b, v0.16b
-ST4(	ld1		{v5.16b}, [x1], #16		)
+ST4(	ld1		{v5.16b}, [IN], #16		)
 	eor		v1.16b, v6.16b, v1.16b
-ST5(	ld1		{v5.16b-v6.16b}, [x1], #32	)
+ST5(	ld1		{v5.16b-v6.16b}, [IN], #32	)
 	eor		v2.16b, v7.16b, v2.16b
 	eor		v3.16b, v5.16b, v3.16b
 ST5(	eor		v4.16b, v6.16b, v4.16b		)
-	st1		{v0.16b-v3.16b}, [x0], #64
-ST5(	st1		{v4.16b}, [x0], #16		)
-	cbz		w4, .Lctrout\xctr
+	st1		{v0.16b-v3.16b}, [OUT], #64
+ST5(	st1		{v4.16b}, [OUT], #16		)
+	cbz		BYTES_W, .Lctrout\xctr
 	b		.LctrloopNx\xctr
 
 .Lctrout\xctr:
 	.if !\xctr
-	st1		{vctr.16b}, [x5] /* return next CTR value */
+	st1		{vctr.16b}, [IV] /* return next CTR value */
 	.endif
 	ldp		x29, x30, [sp], #16
 	ret
 
 .Lctrtail\xctr:
+	/*
+	 * Handle up to MAX_STRIDE * 16 - 1 bytes of plaintext
+	 *
+	 * This code expects the last keystream block to be in v{MAX_STRIDE-1}.
+	 * For example: if encrypting two blocks with MAX_STRIDE=5, then v3 and
+	 * v4 should have the next two counter blocks.
+	 *
+	 * This allows us to store the ciphertext by writing to overlapping
+	 * regions of memory. Any invalid ciphertext blocks get overwritten by
+	 * correctly computed blocks. This approach greatly simplifies the
+	 * logic for storing the ciphertext.
+	 */
 	mov		x16, #16
-	ands		x6, x4, #0xf
-	csel		x13, x6, x16, ne
+	ands		w7, BYTES_W, #0xf
+	csel		x13, x7, x16, ne
 
-ST5(	cmp		w4, #64 - (MAX_STRIDE << 4)	)
+ST5(	cmp		BYTES_W, #64 - (MAX_STRIDE << 4))
 ST5(	csel		x14, x16, xzr, gt		)
-	cmp		w4, #48 - (MAX_STRIDE << 4)
+	cmp		BYTES_W, #48 - (MAX_STRIDE << 4)
 	csel		x15, x16, xzr, gt
-	cmp		w4, #32 - (MAX_STRIDE << 4)
+	cmp		BYTES_W, #32 - (MAX_STRIDE << 4)
 	csel		x16, x16, xzr, gt
-	cmp		w4, #16 - (MAX_STRIDE << 4)
+	cmp		BYTES_W, #16 - (MAX_STRIDE << 4)
 
-	adr_l		x12, .Lcts_permute_table
-	add		x12, x12, x13
+	adr_l		x9, .Lcts_permute_table
+	add		x9, x9, x13
 	ble		.Lctrtail1x\xctr
 
-ST5(	ld1		{v5.16b}, [x1], x14		)
-	ld1		{v6.16b}, [x1], x15
-	ld1		{v7.16b}, [x1], x16
+ST5(	ld1		{v5.16b}, [IN], x14		)
+	ld1		{v6.16b}, [IN], x15
+	ld1		{v7.16b}, [IN], x16
 
 ST4(	bl		aes_encrypt_block4x		)
 ST5(	bl		aes_encrypt_block5x		)
 
-	ld1		{v8.16b}, [x1], x13
-	ld1		{v9.16b}, [x1]
-	ld1		{v10.16b}, [x12]
+	ld1		{v8.16b}, [IN], x13
+	ld1		{v9.16b}, [IN]
+	ld1		{v10.16b}, [x9]
 
 ST4(	eor		v6.16b, v6.16b, v0.16b		)
 ST4(	eor		v7.16b, v7.16b, v1.16b		)
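The tail path above relies on two tricks: the per-lane advance amounts x14, x15 and x16 are either 16 or 0, so the same 4- or 5-block encrypt call covers any residue, and the final, possibly partial block is produced from a shifted keystream block plus an overlapping 16-byte store. The C sketch below illustrates only the overlapping-store idea for a 17 to 31 byte tail; xor_block() and the zero-padded shift stand in for the eor instructions and the tbl through .Lcts_permute_table, and the whole function is an illustration rather than code taken from the patch.

#include <stddef.h>
#include <stdint.h>
#include <string.h>

static void xor_block(uint8_t dst[16], const uint8_t *src,
		      const uint8_t ks[16])
{
	for (int i = 0; i < 16; i++)
		dst[i] = src[i] ^ ks[i];
}

/*
 * Encrypt a 17..31 byte tail given two keystream blocks ks0 and ks1.
 * Every load and store is a full 16 bytes; the partial block is handled
 * by shifting ks1 (as the tbl does) and letting the correct store of
 * the first block overwrite the junk head of the overlapping store.
 */
static void ctr_tail_two_blocks(uint8_t *out, const uint8_t *in, size_t len,
				const uint8_t ks0[16], const uint8_t ks1[16])
{
	size_t rem = len - 16;		/* 1..15 trailing bytes */
	uint8_t shifted[16] = { 0 };	/* tbl v4, {v4}, v10: zeroed head */
	uint8_t first[16], last[16];

	memcpy(shifted + 16 - rem, ks1, rem);

	xor_block(first, in, ks0);		 /* correct first block */
	xor_block(last, in + len - 16, shifted); /* last 16 input bytes */

	memcpy(out + len - 16, last, 16);	/* overlapping store... */
	memcpy(out, first, 16);			/* ...then the fix-up store */
}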
@@ -477,35 +537,70 @@ ST5(	eor		v7.16b, v7.16b, v2.16b	)
 ST5(	eor		v8.16b, v8.16b, v3.16b		)
 ST5(	eor		v9.16b, v9.16b, v4.16b		)
 
-ST5(	st1		{v5.16b}, [x0], x14		)
-	st1		{v6.16b}, [x0], x15
-	st1		{v7.16b}, [x0], x16
-	add		x13, x13, x0
+ST5(	st1		{v5.16b}, [OUT], x14		)
+	st1		{v6.16b}, [OUT], x15
+	st1		{v7.16b}, [OUT], x16
+	add		x13, x13, OUT
 	st1		{v9.16b}, [x13]			// overlapping stores
-	st1		{v8.16b}, [x0]
+	st1		{v8.16b}, [OUT]
 	b		.Lctrout\xctr
 
 .Lctrtail1x\xctr:
-	sub		x7, x6, #16
-	csel		x6, x6, x7, eq
-	add		x1, x1, x6
-	add		x0, x0, x6
-	ld1		{v5.16b}, [x1]
-	ld1		{v6.16b}, [x0]
+	/*
+	 * Handle <= 16 bytes of plaintext
+	 *
+	 * This code always reads and writes 16 bytes. To avoid out of bounds
+	 * accesses, XCTR and CTR modes must use a temporary buffer when
+	 * encrypting/decrypting less than 16 bytes.
+	 *
+	 * This code is unusual in that it loads the input and stores the output
+	 * relative to the end of the buffers rather than relative to the start.
+	 * This causes unusual behaviour when encrypting/decrypting less than 16
+	 * bytes; the end of the data is expected to be at the end of the
+	 * temporary buffer rather than the start of the data being at the start
+	 * of the temporary buffer.
+	 */
+	sub		x8, x7, #16
+	csel		x7, x7, x8, eq
+	add		IN, IN, x7
+	add		OUT, OUT, x7
+	ld1		{v5.16b}, [IN]
+	ld1		{v6.16b}, [OUT]
 ST5(	mov		v3.16b, v4.16b			)
-	encrypt_block	v3, w3, x2, x8, w7
-	ld1		{v10.16b-v11.16b}, [x12]
+	encrypt_block	v3, ROUNDS_W, KEY, x8, w7
+	ld1		{v10.16b-v11.16b}, [x9]
 	tbl		v3.16b, {v3.16b}, v10.16b
 	sshr		v11.16b, v11.16b, #7
 	eor		v5.16b, v5.16b, v3.16b
 	bif		v5.16b, v6.16b, v11.16b
-	st1		{v5.16b}, [x0]
+	st1		{v5.16b}, [OUT]
 	b		.Lctrout\xctr
+
+	// Arguments
+	.unreq OUT
+	.unreq IN
+	.unreq KEY
+	.unreq ROUNDS_W
+	.unreq BYTES_W
+	.unreq IV
+	.unreq BYTE_CTR_W	// XCTR only
+	// Intermediate values
+	.unreq CTR_W		// XCTR only
+	.unreq CTR		// XCTR only
+	.unreq IV_PART
+	.unreq BLOCKS
+	.unreq BLOCKS_W
 .endm
 
 /*
  * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
  *		   int bytes, u8 ctr[])
+ *
+ * The input and output buffers must always be at least 16 bytes even if
+ * encrypting/decrypting less than 16 bytes. Otherwise out of bounds
+ * accesses will occur. The data to be encrypted/decrypted is expected
+ * to be at the end of this 16-byte temporary buffer rather than the
+ * start.
  */
 
 AES_FUNC_START(aes_ctr_encrypt)
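The buffer requirement added to this comment is a contract with the C glue code rather than something the assembly checks. A hypothetical caller-side sketch of it, assuming a local bounce buffer and the prototype documented above, could look like the following:

#include <stdint.h>
#include <string.h>

void aes_ctr_encrypt(uint8_t out[], const uint8_t in[], const uint8_t rk[],
		     int rounds, int bytes, uint8_t ctr[]);

/*
 * Requests shorter than one AES block (len in 1..15) are bounced through
 * a 16-byte buffer, with the message placed at the *end* of that buffer,
 * because the tail code loads and stores a full 16 bytes relative to the
 * end of the data.
 */
static void ctr_crypt_short(uint8_t *dst, const uint8_t *src, int len,
			    const uint8_t *rk, int rounds, uint8_t ctr[16])
{
	uint8_t buf[16];
	uint8_t *p = memcpy(buf + 16 - len, src, len);	/* data at the end */

	aes_ctr_encrypt(p, p, rk, rounds, len, ctr);
	memcpy(dst, p, len);
}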
@@ -515,6 +610,12 @@ AES_FUNC_END(aes_ctr_encrypt)
 /*
  * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
  *		    int bytes, u8 const iv[], int byte_ctr)
+ *
+ * The input and output buffers must always be at least 16 bytes even if
+ * encrypting/decrypting less than 16 bytes. Otherwise out of bounds
+ * accesses will occur. The data to be encrypted/decrypted is expected
+ * to be at the end of this 16-byte temporary buffer rather than the
+ * start.
  */
 
 AES_FUNC_START(aes_xctr_encrypt)
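aes_xctr_encrypt() differs from the CTR entry point in taking the IV by const pointer and carrying the keystream position in byte_ctr, which the macro converts to a block counter with "lsr CTR_W, BYTE_CTR_W, #4". A hypothetical chunked caller, assuming chunk sizes that are multiples of 16 bytes, is sketched below; the helper name and loop are illustrative only.

#include <stdint.h>

void aes_xctr_encrypt(uint8_t out[], const uint8_t in[], const uint8_t rk[],
		      int rounds, int bytes, const uint8_t iv[], int byte_ctr);

/*
 * byte_ctr is the number of keystream bytes already produced, so each
 * call must start on a block boundary.  Chunks shorter than 16 bytes
 * would also need the bounce-buffer treatment shown for CTR above.
 */
static void xctr_crypt_chunked(uint8_t *dst, const uint8_t *src, int total,
			       const uint8_t *rk, int rounds,
			       const uint8_t iv[16], int chunk)
{
	int byte_ctr = 0;

	while (total > 0) {
		int n = total < chunk ? total : chunk;

		aes_xctr_encrypt(dst, src, rk, rounds, n, iv, byte_ctr);
		byte_ctr += n;	/* stays block-aligned while n % 16 == 0 */
		dst += n;
		src += n;
		total -= n;
	}
}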