@@ -62,6 +62,28 @@ class AesCircuitTests : public AesCircuit<BitType> {
6262 }
6363 }
6464
65+ void testInverseShiftRowInPlace (std::vector<bool > plaintext) {
66+ std::array<std::array<std::array<bool , 8 >, 4 >, 4 > block;
67+ for (int k = 0 ; k < 4 ; ++k) {
68+ for (int i = 0 ; i < 4 ; i++) {
69+ for (int j = 0 ; j < 8 ; j++) {
70+ block[k][i][j] = plaintext[32 * k + 8 * i + j];
71+ }
72+ }
73+ }
74+
75+ AesCircuit<bool >::inverseShiftRowInPlace (block);
76+ for (int k = 0 ; k < 4 ; ++k) {
77+ for (int i = 0 ; i < 4 ; i++) {
78+ for (int j = 0 ; j < 8 ; j++) {
79+ EXPECT_EQ (
80+ block[k][i][j],
81+ plaintext[32 * ((((k - i) % 4 ) + 4 ) % 4 ) + 8 * i + j]);
82+ }
83+ }
84+ }
85+ }
86+
6587 void testWordConversion () {
6688 using ByteType = std::array<bool , 8 >;
6789 using WordType = std::array<ByteType, 4 >;
@@ -159,6 +181,12 @@ TEST(AesCircuitTest, testShiftRowInPlace) {
159181 test.testShiftRowInPlace (plaintext);
160182}
161183
184+ TEST (AesCircuitTest, testInverseShiftRowInPlace) {
185+ auto plaintext = generateRandomPlaintext ();
186+ AesCircuitTests<bool > test;
187+ test.testInverseShiftRowInPlace (plaintext);
188+ }
189+
162190TEST (AesCircuitTest, testWordConversion) {
163191 AesCircuitTests<bool > test;
164192 test.testWordConversion ();
@@ -352,6 +380,65 @@ TEST(AesCircuitTest, testAesCircuitEncrypt) {
352380 testAesCircuitEncrypt (std::make_unique<AesCircuitFactory<bool >>());
353381}
354382
383+ void testAesCircuitDecrypt (
384+ std::shared_ptr<AesCircuitFactory<bool >> AesCircuitFactory) {
385+ auto AesCircuit = AesCircuitFactory->create ();
386+
387+ std::random_device rd;
388+ std::mt19937_64 e (rd ());
389+ std::uniform_int_distribution<uint8_t > dist (0 , 0xFF );
390+ size_t blockNo = dist (e);
391+
392+ // generate random key
393+ __m128i key = _mm_set_epi32 (dist (e), dist (e), dist (e), dist (e));
394+ // generate random plaintext
395+ std::vector<uint8_t > plaintext;
396+ plaintext.reserve (blockNo * 16 );
397+ for (int i = 0 ; i < blockNo * 16 ; ++i) {
398+ plaintext.push_back (dist (e));
399+ }
400+ std::vector<__m128i> plaintextAES;
401+ loadValueToLocalAes (plaintext, plaintextAES);
402+
403+ // expand key
404+ engine::util::Aes truthAes (key);
405+ auto expandedKey = truthAes.expandEncryptionKey (key);
406+ // extract key and plaintext
407+ std::vector<uint8_t > extractedKeys;
408+ extractedKeys.reserve (176 );
409+ for (auto keyb : expandedKey) {
410+ loadValueFromLocalAes (keyb, extractedKeys);
411+ }
412+
413+ // convert key and plaintext into bool vector
414+ std::vector<bool > keyBits;
415+ keyBits.reserve (1408 );
416+ int8VecToBinaryVec (extractedKeys, keyBits);
417+ std::vector<bool > plaintextBits;
418+ plaintextBits.reserve (blockNo * 128 );
419+ int8VecToBinaryVec (plaintext, plaintextBits);
420+
421+ // encrypt in real aes
422+ truthAes.encryptInPlace (plaintextAES);
423+
424+ // extract ciphertext in real aes
425+ std::vector<uint8_t > ciphertextTruth;
426+ ciphertextTruth.reserve (blockNo * 16 );
427+ for (auto b : plaintextAES) {
428+ loadValueFromLocalAes (b, ciphertextTruth);
429+ }
430+ std::vector<bool > cipherextBitsTruth;
431+ cipherextBitsTruth.reserve (blockNo * 128 );
432+ int8VecToBinaryVec (ciphertextTruth, cipherextBitsTruth);
433+ // decrypt this ciphertext using our decrypt circuit
434+ auto decryptionBits = AesCircuit->decrypt (cipherextBitsTruth, keyBits);
435+ testVectorEq (decryptionBits, plaintextBits);
436+ }
437+
438+ TEST (AesCircuitTest, testAesCircuitDecrypt) {
439+ testAesCircuitDecrypt (std::make_unique<AesCircuitFactory<bool >>());
440+ }
441+
355442void testAesCircuitCtr (
356443 std::shared_ptr<AesCircuitCtrFactory<bool >> AesCircuitCtrFactory) {
357444 auto AesCircuitCtr = AesCircuitCtrFactory->create ();
0 commit comments