x86/x64のAESNI命令群の使い方を解説します。
AES暗号については、米国の政府機関が出しているFIPS 197という文書がオリジナルの仕様になりますのでそちらを見ながら読んでください。
AES暗号ではキーサイズは128ビット、192ビット、256ビットの3種類があり、AES-128、AES-192、AES-256と呼ばれます。
各キーサイズごとに固有の定数Nk(キー長)、Nb(ブロック長)、Nr(ラウンド回数)があり下のように決まっています。NkとNbの単位はwordで、AESでは「word」は32ビットです。
(FIPS 197より抜粋)
煩わしいので以下Nbは4と書きます。
AES暗号ではどのキー長でもブロック長は4ワード(16バイト)固定です。暗号化では16バイトの平文を入力して16バイトの暗号文を得ます。復号では16バイトの暗号文を入力して16バイトの平文を得ます。
state配列は4*4バイトの二次元配列で、ここに平文を入れたあと、roundと呼ばれる処理を繰り返して加工することで暗号文を作ります。roundはNr回繰り返します。
(FIPS 197より抜粋)
AESNIでは、XMMWORDひとつにこの配列を丸ごと入れて処理します。
AESNI命令の使用の際には、バイトオーダー(エンディアン)の変換を行う必要はありません。
w配列は、Nkワードのキーをもとに、あらかじめKeyExpansionという処理をして作っておく4 * (Nr + 1)ワードの一次元配列です。round 1回ごとに4ワードずつ使っていきます。roundの周回に入る前に下処理でひとつ使うのでNr+1回分が必要です。
AESNIでは、XMMWORDひとつにround 1回分(4ワード)の要素を丸ごと入れて処理します(AESのワードは32ビットです)。
以下はFIPS 197に書いてある暗号化のアルゴリズムです。色は解説の都合上私がつけたものです。AESNI命令ではAとBの部分を行います。CのAddRoundKey()はxorしているだけなのでpxor命令でできます。
AESENC xmm1, xmm2/m128 (AESNI
__m128i _mm_aesenc_si128(__m128i state, __m128i w);
AESENCLAST xmm1, xmm2/m128 (AESNI
__m128i _mm_aesenclast_si128(__m128i state, __m128i w);
VAESENC xmm1, xmm2, xmm3/m128 (AESNI+(V1
VAESENCLAST xmm1, xmm2, xmm3/m128 (AESNI+(V1
VAESENC ymm1, ymm2, ymm3/m256 (VAES
__m256i _mm256_aesenc_epi128(__m256i state, __m256i w);
VAESENCLAST ymm1, ymm2, ymm3/m256 (VAES
__m256i _mm256_aesenclast_epi128(__m256i state, __m256i w);
VAESENC zmm1, zmm2, zmm3/m512 (VAES+(V5
__m512i _mm512_aesenc_epi128(__m512i state, __m512i w);
VAESENCLAST zmm1, zmm2, zmm3/m512 (VAES+(V5
__m512i _mm512_aesenclast_epi128(__m512i state, __m512i w);
AESENCは、Aの4つの関数を呼び出している部分(1ラウンド分)を1命令で処理します。
Bは最終ラウンドですが、この回だけ呼び出す関数がひとつ少ないので、AESENCLAST命令で処理します。
①にラウンド前のstate配列、②にそのラウンドで使うw配列の要素を入れて実行すると、③にラウンド後のstate配列が返ります。
命令の中で具体的に何をやっているかはFIPS 197の各関数の説明をご覧ください。
暗号化の前にKey Expansionで暗号キー(Nkワードのデータ)からw配列を作ります。
以下はFIPS 197のKeyExpansionのアルゴリズムです。
Rcon配列はFIPS 197で定義された方法で計算された定数の配列で、具体的な値は以下のようになります。AES暗号で実際に使われるのは[1]~[10]の範囲だけです。
static const BYTE Rcon[] = {
0x8d, 0x01, 0x02, 0x04, 0x08, 0x10,
0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
};
keyをw配列にコピーする際に、バイトオーダー(エンディアン)の変換を行う必要はありません。
AESKEYGENASSIST xmm1, xmm2/m128, imm8 (AESNI
VAESKEYGENASSIST xmm1, xmm2/m128, imm8 (AESNI+(V1
__m128i _mm_aeskeygenassist_si128(__m128i temp, const int Rcon);
AESKEYGENASSISTは、①にtempの値を入れて、②にRcon[i/Nk]の値を指定して実行すると、A、Bの2つの式の値を計算して③に返してくれます。
FIPS 197では、復号のアルゴリズムとして、暗号化の手順を逆にしただけの「InvCipher」と、それと等価な結果を得られる「EqInvCipher」という2つが定義されていますが、AESNIがサポートするのは後者のみです。
EqInvCipherでは、w配列の生成処理に後処理をちょっと追加して作ったdw配列を使います。dw配列の作り方は後述します。
AESDEC xmm1, xmm2/m128 (AESNI
__m128i _mm_aesdec_si128(__m128i state, __m128i dw);
AESDECLAST xmm1, xmm2/m128 (AESNI
__m128i _mm_aesdeclast_si128(__m128i state, __m128i dw);
VAESDEC xmm1, xmm2, xmm3/m128 (AESNI+(V1
VAESDECLAST xmm1, xmm2, xmm3/m128 (AESNI+(V1
VAESDEC ymm1, ymm2, ymm3/m256 (VAES
__m256i _mm256_aesdec_epi128(__m256i state, __m256i dw);
VAESDECLAST ymm1, ymm2, ymm3/m256 (VAES
__m256i _mm256_aesdeclast_epi128(__m256i state, __m256i dw);
VAESDEC zmm1, zmm2, zmm3/m512 (VAES+(V5
__m512i _mm512_aesdec_epi128(__m512i state, __m512i dw);
VAESDECLAST zmm1, zmm2, zmm3/m512 (VAES+(V5
__m512i _mm512_aesdeclast_epi128(__m512i state, __m512i dw);
AESDECは、Aの4つの関数を呼び出している部分(1ラウンド分)を1命令で処理します。
Bは最終ラウンドですが、この回だけ呼び出す関数がひとつ少ないので、AESDECLAST命令で処理します。
①にラウンド前のstate配列、②にそのラウンドで使うdw配列の要素を入れて実行すると、③にラウンド後のstate配列が返ります。
EqInvCipher用のdw配列は、前述のKeyExpansion処理で作ったw配列に、以下の後処理をすることで作ります。
AESIMC xmm1, xmm2/m128 (AESNI
VAESIMC xmm1, xmm2/m128 (AESNI+(V1
__m128i _mm_aesimc_si128(__m128i w);
上のInvMixColumns()の部分を処理する命令です。w配列の要素を①に入れて実行すると、後処理後のdw配列の要素が得られます。
AES-128 AES-192 AES-256
#pragma once
#include <intrin.h>
class AES128_NI
{
public:
// AES128
static const int Nk = 4;
static const int Nb = 4;
static const int Nr = 10;
protected:
__m128i w128[Nr + 1];
__m128i dw128[Nr + 1];
bool decrypting;
public:
AES128_NI(const unsigned char key[4 * Nk], bool decrypting)
: decrypting(decrypting)
{
KeyExpansion(key);
}
void Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
void EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
protected:
void KeyExpansion(const unsigned char key[4 * Nk]);
};
#include "AES128_NI.h" //Cipher(byte in[4*Nb], byte out[4*Nb], word w[Nb*(Nr+1)]) //begin void AES128_NI::Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]) { // ASSERT(!decrypting); //state = in __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in)); //AddRoundKey(state, w[0, Nb-1]) // just XOR state = _mm_xor_si128(state, w128[0]); //for round = 1 step 1 to Nr-1 //SubBytes(state) //ShiftRows(state) //MixColumns(state) //AddRoundKey(state, w[round*Nb, (round+1)*Nb-1]) //end for state = _mm_aesenc_si128(state, w128[1]); state = _mm_aesenc_si128(state, w128[2]); state = _mm_aesenc_si128(state, w128[3]); state = _mm_aesenc_si128(state, w128[4]); state = _mm_aesenc_si128(state, w128[5]); state = _mm_aesenc_si128(state, w128[6]); state = _mm_aesenc_si128(state, w128[7]); state = _mm_aesenc_si128(state, w128[8]); state = _mm_aesenc_si128(state, w128[9]); // The last round //SubBytes(state) //ShiftRows(state) //AddRoundKey(state, w[Nr*Nb, (Nr+1)*Nb-1]) state = _mm_aesenclast_si128(state, w128[Nr]); //out = state _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state); //end } //EqInvCipher(byte in[4*Nb], byte out[4*Nb], word dw[Nb*(Nr+1)]) //begin void AES128_NI::EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]) { // ASSERT(decrypting); //byte state[4,Nb] //state = in __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in)); //AddRoundKey(state, dw[Nr*Nb, (Nr+1)*Nb-1]) state = _mm_xor_si128(state, w128[Nr]); //for round = Nr-1 step -1 downto 1 //InvSubBytes(state) //InvShiftRows(state) //InvMixColumns(state) //AddRoundKey(state, dw[round*Nb, (round+1)*Nb-1]) //end for state = _mm_aesdec_si128(state, dw128[9]); state = _mm_aesdec_si128(state, dw128[8]); state = _mm_aesdec_si128(state, dw128[7]); state = _mm_aesdec_si128(state, dw128[6]); state = _mm_aesdec_si128(state, dw128[5]); state = _mm_aesdec_si128(state, dw128[4]); state = _mm_aesdec_si128(state, dw128[3]); state = _mm_aesdec_si128(state, dw128[2]); state = _mm_aesdec_si128(state, dw128[1]); //InvSubBytes(state) //InvShiftRows(state) //AddRoundKey(state, dw[0, Nb-1]) state = _mm_aesdeclast_si128(state, dw128[0]); //out = state _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state); //end } //static const BYTE Rcon[] = { // 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, //}; //KeyExpansion(byte key[4*Nk], word w[Nb*(Nr+1)], Nk) //begin void AES128_NI::KeyExpansion(const unsigned char key[4 * Nk]) { //word temp //i = 0 //while (i < Nk) //w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3]) //i = i+1 //end while __m128i work = _mm_loadu_si128(reinterpret_cast<const __m128i*>(key)); // work = w3 : w2 : w1 : w0 w128[0] = work; //i = Nk //while (i < Nb * (Nr+1)] //temp = w[i-1] //if (i mod Nk = 0) //temp = SubWord(RotWord(temp)) xor Rcon[i/Nk] //else if (Nk > 6 and i mod Nk = 4) //temp = SubWord(temp) //end if //w[i] = w[i-Nk] xor temp //i = i + 1 //end while __m128i t, t2; // f(w) = SubWord(RotWord(w)) xor Rcon // work = w3 : w2 : w1 : w0 // w4 = w0^f(w3) // w5 = w1^w4 = w1^w0^f(w3) // w6 = w2^w5 = w2^w1^w0^f(w3) // w7 = w3^w6 = w3^w2^w1^w0^f(w3) #define EXPAND(n, RCON) \ t = _mm_slli_si128(work, 4); /* t = w2 : w1 : w0 : 0 */\ t = _mm_xor_si128(work, t); /* t = w3^w2 : w2^w1 : w1^w0 : w0 */\ t2 = _mm_slli_si128(t, 8); /* t2 = w1^w0 : w0 : 0 : 0 */\ t = _mm_xor_si128(t, t2); /* t = w3^w2^w1^w0 : w2^w1^w0 : w1^w0 : w0 */\ work = _mm_aeskeygenassist_si128(work, RCON); /* work = f(w3) : - : - : - */ \ work = _mm_shuffle_epi32(work, 0xFF);/*work = f(w3) : f(w3) : f(w3) ; f(w3) */\ work = _mm_xor_si128(t, work); /* work = w3^w2^w1^w0^f(w3) : w2^w1^w0^f(w3) : w1^w0^f(w3) : w0^f(w3) */\ w128[n] = work; /* work = w7 : w6 : w5 : w4 */ EXPAND(1, 0x01); // Go on... EXPAND(2, 0x02); EXPAND(3, 0x04); EXPAND(4, 0x08); EXPAND(5, 0x10); EXPAND(6, 0x20); EXPAND(7, 0x40); EXPAND(8, 0x80); EXPAND(9, 0x1b); EXPAND(10, 0x36); // Additional process for EqInvCipher if (decrypting) { //for i = 0 step 1 to (Nr+1)*Nb-1 //dw[i] = w[i] //end for dw128[0] = w128[0]; //for round = 1 step 1 to Nr-1 //InvMixColumns(dw[round*Nb, (round+1)*Nb-1]) //end for dw128[1] = _mm_aesimc_si128(w128[1]); dw128[2] = _mm_aesimc_si128(w128[2]); dw128[3] = _mm_aesimc_si128(w128[3]); dw128[4] = _mm_aesimc_si128(w128[4]); dw128[5] = _mm_aesimc_si128(w128[5]); dw128[6] = _mm_aesimc_si128(w128[6]); dw128[7] = _mm_aesimc_si128(w128[7]); dw128[8] = _mm_aesimc_si128(w128[8]); dw128[9] = _mm_aesimc_si128(w128[9]); dw128[Nr] = w128[Nr]; } //end }
#pragma once
#include <intrin.h>
class AES192_NI
{
public:
// AES192
static const int Nk = 6;
static const int Nb = 4;
static const int Nr = 12;
protected:
__m128i w128[Nr + 1];
__m128i dw128[Nr + 1];
bool decrypting;
public:
AES192_NI(const unsigned char key[4 * Nk], bool decrypting)
: decrypting(decrypting)
{
KeyExpansion(key);
}
void Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
void EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
protected:
void KeyExpansion(const unsigned char key[4 * Nk]);
};
#include "AES192_NI.h" //Cipher(byte in[4*Nb], byte out[4*Nb], word w[Nb*(Nr+1)]) //begin void AES192_NI::Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]) { // ASSERT(!decrypting); //state = in __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i*>(in)); //AddRoundKey(state, w[0, Nb-1]) // just XOR state = _mm_xor_si128(state, w128[0]); //for round = 1 step 1 to Nr-1 //SubBytes(state) //ShiftRows(state) //MixColumns(state) //AddRoundKey(state, w[round*Nb, (round+1)*Nb-1]) //end for state = _mm_aesenc_si128(state, w128[1]); state = _mm_aesenc_si128(state, w128[2]); state = _mm_aesenc_si128(state, w128[3]); state = _mm_aesenc_si128(state, w128[4]); state = _mm_aesenc_si128(state, w128[5]); state = _mm_aesenc_si128(state, w128[6]); state = _mm_aesenc_si128(state, w128[7]); state = _mm_aesenc_si128(state, w128[8]); state = _mm_aesenc_si128(state, w128[9]); state = _mm_aesenc_si128(state, w128[10]); state = _mm_aesenc_si128(state, w128[11]); // The last round //SubBytes(state) //ShiftRows(state) //AddRoundKey(state, w[Nr*Nb, (Nr+1)*Nb-1]) state = _mm_aesenclast_si128(state, w128[Nr]); //out = state _mm_storeu_si128(reinterpret_cast<__m128i*>(out), state); //end } //EqInvCipher(byte in[4*Nb], byte out[4*Nb], word dw[Nb*(Nr+1)]) //begin void AES192_NI::EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]) { // ASSERT(decrypting); //byte state[4,Nb] //state = in __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i*>(in)); //AddRoundKey(state, dw[Nr*Nb, (Nr+1)*Nb-1]) state = _mm_xor_si128(state, dw128[Nr]); //for round = Nr-1 step -1 downto 1 //InvSubBytes(state) //InvShiftRows(state) //InvMixColumns(state) //AddRoundKey(state, dw[round*Nb, (round+1)*Nb-1]) //end for state = _mm_aesdec_si128(state, dw128[11]); state = _mm_aesdec_si128(state, dw128[10]); state = _mm_aesdec_si128(state, dw128[9]); state = _mm_aesdec_si128(state, dw128[8]); state = _mm_aesdec_si128(state, dw128[7]); state = _mm_aesdec_si128(state, dw128[6]); state = _mm_aesdec_si128(state, dw128[5]); state = _mm_aesdec_si128(state, dw128[4]); state = _mm_aesdec_si128(state, dw128[3]); state = _mm_aesdec_si128(state, dw128[2]); state = _mm_aesdec_si128(state, dw128[1]); //InvSubBytes(state) //InvShiftRows(state) //AddRoundKey(state, dw[0, Nb-1]) state = _mm_aesdeclast_si128(state, dw128[0]); //out = state _mm_storeu_si128(reinterpret_cast<__m128i*>(out), state); //end } //static const BYTE Rcon[] = { // 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, //}; //KeyExpansion(byte key[4*Nk], word w[Nb*(Nr+1)], Nk) //begin void AES192_NI::KeyExpansion(const unsigned char key[4 * Nk]) { //word temp //i = 0 //while (i < Nk) //w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3]) //i = i+1 //end while __m128i work1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(key)); // w3 : w2 : w1 : w0 __m128i work2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(key + 16)); // - : - : w5 : w4 w128[0] = work1; //i = Nk //while (i < Nb * (Nr+1)] //temp = w[i-1] //if (i mod Nk = 0) //temp = SubWord(RotWord(temp)) xor Rcon[i/Nk] //else if (Nk > 6 and i mod Nk = 4) //temp = SubWord(temp) //end if //w[i] = w[i-Nk] xor temp //i = i + 1 //end while __m128i t, t2; __m128i work3; // f(w) = SubWord(RotWord(w)) xor Rcon // work1 = w3 : w2 : w1 : w0 // work2 = - : - : w5 : w4 // w6 = w0^f(w5) // w7 = w1^w6 = w1^w0^f(w5) t = _mm_aeskeygenassist_si128(work2, 0x01); /* t = - : - : f(w5) : - */ t = _mm_shuffle_epi32(t, 0x55); /* t = f(w5) : f(w5) : f(w5) : f(w5) */ t2 = _mm_slli_si128(work1, 4); /* t2 = - : - : w0 : 0 */ t2 = _mm_xor_si128(work1, t2); /* t2 = - : - : w1^w0 : w0 */ t = _mm_xor_si128(t2, t); /* t = - : - : w7 : w6 */ work2 = _mm_unpacklo_epi64(work2, t); /* work2 = w7 : w6 : w5 : w4 */ w128[1] = work2; // work1 = w3 : w2 : w1 : w0 // work2 = w7 : w6 : w5 : w4 // w8 = w2^w7 // w9 = w3^w8 = w3^w2^w7 // w10 = w4^w9 = w4^w3^w2^w7 // w11 = w5^w10 = w5^w4^w3^w2^w7 #define EXPAND1(n) \ t = _mm_alignr_epi8(work2, work1, 8);/* t = w5 : w4 : w3 : w2 */\ work3 = _mm_shuffle_epi32(work2, 0xFF); /* work3 = w7 : w7 : w7 : w7 */\ t2 = _mm_slli_si128(t, 4); /* t2 = w4 : w3 : w2 : 0 */\ t = _mm_xor_si128(t, t2); /* t = w5^w4 : w4^w3 : w3^w2 : w2 */\ t2 = _mm_slli_si128(t, 8); /* t2 = w3^w2 : w2 : 0 : 0 */\ t2 = _mm_xor_si128(t, t2); /* t2 = w5^w4^w3^w2 : w4^w3^w2 : w3^w2 : w2 */\ work3 = _mm_xor_si128(t2, work3); /* work3 = w11 : w10 : w9 : w8 */\ w128[n] = work3; EXPAND1(2); // work2 = w7 : w6 : w5 : w4 // work3 = w11: w10 : w9 : w8 // w12 = w6^f(w11) // w13 = w7^w12 = w7^w6^f(w11) // w14 = w8^w13 = w8^w7^w6^f(w11) // w15 = w9^w14 = w9^w8^w7^w6^f(w11) #define EXPAND2(n, RCON) \ t = _mm_alignr_epi8(work3, work2, 8);/*t = w9 : w8 : w7 : w6 */\ work1 = _mm_aeskeygenassist_si128(work3, RCON);/* work1 = f(w11) : - : - : - */\ work1 = _mm_shuffle_epi32(work1, 0xFF); /* work1 = f(w11) : f(w11) : f(w11) : f(w11) */\ t2 = _mm_slli_si128(t, 4); /* t2 = w8 : w7 : w6 : 0 */\ t = _mm_xor_si128(t, t2); /* t = w9^w8 : w8^w7 : w7^w6 : w6 */\ t2 = _mm_slli_si128(t, 8); /* t2 = w7^w6 : w6 : 0 : 0 */\ t = _mm_xor_si128(t, t2); /* t = w9^w8^w7^w6 : w8^w7^w6 : w7^w6 : w6 */\ work1 = _mm_xor_si128(t, work1); /* work1 = w15 : w14 : w13 : w12 */\ w128[n] = work1; EXPAND2(3, 0x02); const __m128i idx1 = _mm_set_epi8(-1, -1, -1, -1, -1, -1, -1, -1, 15, 14, 13, 12, 15, 14, 13, 12); const __m128i idx2 = _mm_set_epi8(7, 6, 5, 4, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1); // work3 = w11 : w10 : w9 : w8 // work1 = w15 : w14 : w13 : w12 // w16 = w10^w15 // w17 = w11^w16 = w11^w10^w15 // w18 = w12^f(w17) // w19 = w13^w18 = w13^w12^f(w17) #define EXPAND3(n, RCON) \ work2 = _mm_alignr_epi8(work1, work3, 8);/*work2 = w13 : w12 : w11 : w10 */\ t = _mm_slli_epi64(work2, 32); /* t = w12 : 0 : w10 : 0 */\ work2 = _mm_xor_si128(work2, t); /* work2 = w13^w12 : w12 : w11^w10 : w10 */\ t = _mm_shuffle_epi8(work1, idx1); /* t = 0 : 0 : w15 : w15 */\ work2 = _mm_xor_si128(work2, t); /* work2 = w13^w12 : w12 : w17 : w16 */\ t = _mm_aeskeygenassist_si128(work2, RCON);/*t=- : - : f(w17) : - */\ t = _mm_shuffle_epi8(t, idx2); /* t = f(w17) : f(w17) : 0 : 0 */\ work2 = _mm_xor_si128(work2, t); /* work2 = w19 : w18 : w17 : w16 */\ w128[n] = work2; EXPAND3(4, 0x04); // Go on... EXPAND1(5); // w20-w23 EXPAND2(6, 0x08); // w24-w27 EXPAND3(7, 0x10); // w28-w31 EXPAND1(8); // w32-w35 EXPAND2(9, 0x20); // w36-w39 EXPAND3(10, 0x40); // w40-w43 EXPAND1(11); // w44-w47 EXPAND2(12, 0x80); // w48-w51 // Additional process for EqInvCipher if (decrypting) { //for i = 0 step 1 to (Nr+1)*Nb-1 //dw[i] = w[i] //end for dw128[0] = w128[0]; //for round = 1 step 1 to Nr-1 //InvMixColumns(dw[round*Nb, (round+1)*Nb-1]) //end for dw128[1] = _mm_aesimc_si128(w128[1]); dw128[2] = _mm_aesimc_si128(w128[2]); dw128[3] = _mm_aesimc_si128(w128[3]); dw128[4] = _mm_aesimc_si128(w128[4]); dw128[5] = _mm_aesimc_si128(w128[5]); dw128[6] = _mm_aesimc_si128(w128[6]); dw128[7] = _mm_aesimc_si128(w128[7]); dw128[8] = _mm_aesimc_si128(w128[8]); dw128[9] = _mm_aesimc_si128(w128[9]); dw128[10] = _mm_aesimc_si128(w128[10]); dw128[11] = _mm_aesimc_si128(w128[11]); dw128[Nr] = w128[Nr]; } //end }
#pragma once
#include <intrin.h>
class AES256_NI
{
public:
// AES256
static const int Nk = 8;
static const int Nb = 4;
static const int Nr = 14;
protected:
__m128i w128[Nr + 1];
__m128i dw128[Nr + 1];
bool decrypting;
public:
AES256_NI(const unsigned char key[4 * Nk], bool decrypting)
: decrypting(decrypting)
{
KeyExpansion(key);
}
void Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
void EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]);
protected:
void KeyExpansion(const unsigned char key[4 * Nk]);
};
#include "AES256_NI.h" //Cipher(byte in[4*Nb], byte out[4*Nb], word w[Nb*(Nr+1)]) //begin void AES256_NI::Cipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]) { // ASSERT(!decrypting); //state = in __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in)); //AddRoundKey(state, w[0, Nb-1]) // Just XOR state = _mm_xor_si128(state, w128[0]); //for round = 1 step 1 to Nr-1 //SubBytes(state) //ShiftRows(state) //MixColumns(state) //AddRoundKey(state, w[round*Nb, (round+1)*Nb-1]) //end for state = _mm_aesenc_si128(state, w128[1]); state = _mm_aesenc_si128(state, w128[2]); state = _mm_aesenc_si128(state, w128[3]); state = _mm_aesenc_si128(state, w128[4]); state = _mm_aesenc_si128(state, w128[5]); state = _mm_aesenc_si128(state, w128[6]); state = _mm_aesenc_si128(state, w128[7]); state = _mm_aesenc_si128(state, w128[8]); state = _mm_aesenc_si128(state, w128[9]); state = _mm_aesenc_si128(state, w128[10]); state = _mm_aesenc_si128(state, w128[11]); state = _mm_aesenc_si128(state, w128[12]); state = _mm_aesenc_si128(state, w128[13]); // The last round //SubBytes(state) //ShiftRows(state) //AddRoundKey(state, w[Nr*Nb, (Nr+1)*Nb-1]) state = _mm_aesenclast_si128(state, w128[Nr]); //out = state _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state); //end } //EqInvCipher(byte in[4*Nb], byte out[4*Nb], word dw[Nb*(Nr+1)]) //begin void AES256_NI::EqInvCipher(const unsigned char in[4 * Nb], unsigned char out[4 * Nb]) { // ASSERT(decrypting); //byte state[4,Nb] //state = in __m128i state = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in)); //AddRoundKey(state, dw[Nr*Nb, (Nr+1)*Nb-1]) state = _mm_xor_si128(state, dw128[Nr]); //for round = Nr-1 step -1 downto 1 //InvSubBytes(state) //InvShiftRows(state) //InvMixColumns(state) //AddRoundKey(state, dw[round*Nb, (round+1)*Nb-1]) //end for state = _mm_aesdec_si128(state, dw128[13]); state = _mm_aesdec_si128(state, dw128[12]); state = _mm_aesdec_si128(state, dw128[11]); state = _mm_aesdec_si128(state, dw128[10]); state = _mm_aesdec_si128(state, dw128[9]); state = _mm_aesdec_si128(state, dw128[8]); state = _mm_aesdec_si128(state, dw128[7]); state = _mm_aesdec_si128(state, dw128[6]); state = _mm_aesdec_si128(state, dw128[5]); state = _mm_aesdec_si128(state, dw128[4]); state = _mm_aesdec_si128(state, dw128[3]); state = _mm_aesdec_si128(state, dw128[2]); state = _mm_aesdec_si128(state, dw128[1]); //InvSubBytes(state) //InvShiftRows(state) //AddRoundKey(state, dw[0, Nb-1]) state = _mm_aesdeclast_si128(state, dw128[0]); //out = state _mm_storeu_si128(reinterpret_cast<__m128i *>(out), state); //end } //static const BYTE Rcon[] = { // 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, //}; //KeyExpansion(byte key[4*Nk], word w[Nb*(Nr+1)], Nk) //begin void AES256_NI::KeyExpansion(const unsigned char key[4 * Nk]) { //word temp //i = 0 //while (i < Nk) //w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3]) //i = i+1 //end while __m128i work1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(key)); __m128i work2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(key + 16)); w128[0] = work1; // work1 = w3 : w2 : w1 : w0 w128[1] = work2; // work2 = w7 : w6 : w5 : w4 //i = Nk //while (i < Nb * (Nr+1)] //temp = w[i-1] //if (i mod Nk = 0) //temp = SubWord(RotWord(temp)) xor Rcon[i/Nk] //else if (Nk > 6 and i mod Nk = 4) //temp = SubWord(temp) //end if //w[i] = w[i-Nk] xor temp //i = i + 1 //end while __m128i t; // f(w) = SubWord(RotWord(w)) xor Rcon // g(w) = SubWord(w) // work1 = w3 : w2 : w1 : w0 // work2 = w7 : w6 : w5 : w4 // w8 = w0^f(w7) // w9 = w1^w8 = w1^w0^f(w7) // w10 = w2^w9 = w2^w1^w0^f(w7) // w11 = w3^w11 = w3^w2^w1^w0^f(w7) #define EXPAND1(n, RCON) \ t = _mm_slli_si128(work1, 4); /* t = w2 : w1 : w0 : 0 */ \ work1 = _mm_xor_si128(work1, t);/* work1 = w3^w2 : w2^w1 : w1^w0 : w0 */ \ t = _mm_slli_si128(work1, 8); /* t = w1^w0 : w0 : 0 : 0 */ \ work1 = _mm_xor_si128(work1, t);/* work1 = w3^w2^w1^w0 : w2^w1^w0 : w1^w0 : w0 */ \ t = _mm_aeskeygenassist_si128(work2, RCON); /* t = f(w7) : - : - : - */ \ t = _mm_shuffle_epi32(t, 0xFF); /* t = f(w7) : f(w7) : f(w7) : f(w7) */ \ work1 = _mm_xor_si128(work1, t);/* work1 = w3^w2^w1^w0^f(w7) : w2^w1^w0^f(w7) : w1^w0^f(w7) : w0^f(w7) */ \ w128[n] = work1; /* work1 = w11 : w10 : w9 : w8 */ EXPAND1(2, 0x01); // work2 = w7 : w6 : w5 : w4 // work1 = w11 : w10 : w9 : w8 // w12 = w4^g(w11) // w13 = w5^w12 = w5^w4^g(w11) // w14 = w6^w13 = w6^w5^w4^g(w11) // w15 = w7^w14 = w7^w6^w5^w4^g(w11) #define EXPAND2(n) \ t = _mm_slli_si128(work2, 4); /* t = w6 : w5 : w4 : 0 */ \ work2 = _mm_xor_si128(work2, t);/* work2 = w7^w6 : w6^w5 : w5^w4 : w4 */ \ t = _mm_slli_si128(work2, 8); /* t = w5^w4 : w4 : 0 : 0 */ \ work2 = _mm_xor_si128(work2, t);/* work2 = w7^w6^w5^w4 : w6^w5^w4 : w5^w4 : w4 */ \ t = _mm_aeskeygenassist_si128(work1, 0); /* t = - : g(w11) : - : - */ \ t = _mm_shuffle_epi32(t, 0xAA); /* t = g(w11) : g(w11) : g(w11) : g(w11) */ \ work2 = _mm_xor_si128(work2, t);/* work2 = w7^w6^w5^w4^g(w11): w6^w5^w4^g(w11): w5^w4^g(w11): w4^g(w11)*/ \ w128[n] = work2; /* work2 = w15 : w14 : w13 : w12 */ EXPAND2(3); // Go on... EXPAND1(4, 0x02); EXPAND2(5); EXPAND1(6, 0x04); EXPAND2(7); EXPAND1(8, 0x08); EXPAND2(9); EXPAND1(10, 0x10); EXPAND2(11); EXPAND1(12, 0x20); EXPAND2(13); EXPAND1(14, 0x40); // Additional process for EqInvCipher if (decrypting) { //for i = 0 step 1 to (Nr+1)*Nb-1 //dw[i] = w[i] //end for dw128[0] = w128[0]; //for round = 1 step 1 to Nr-1 //InvMixColumns(dw[round*Nb, (round+1)*Nb-1]) //end for dw128[1] = _mm_aesimc_si128(w128[1]); dw128[2] = _mm_aesimc_si128(w128[2]); dw128[3] = _mm_aesimc_si128(w128[3]); dw128[4] = _mm_aesimc_si128(w128[4]); dw128[5] = _mm_aesimc_si128(w128[5]); dw128[6] = _mm_aesimc_si128(w128[6]); dw128[7] = _mm_aesimc_si128(w128[7]); dw128[8] = _mm_aesimc_si128(w128[8]); dw128[9] = _mm_aesimc_si128(w128[9]); dw128[10] = _mm_aesimc_si128(w128[10]); dw128[11] = _mm_aesimc_si128(w128[11]); dw128[12] = _mm_aesimc_si128(w128[12]); dw128[13] = _mm_aesimc_si128(w128[13]); dw128[Nr] = w128[Nr]; } //end }