Bit-slicing AES¶
In this example, we build a Boolean circuit for a bit-sliced implementation of AES. The implementation is based on the white-box AES of BU18.
Operations on vectors¶
import operator
class Vector(list):
    """A ``list`` with rotation, splitting and element-wise bit operations.

    Elements can be anything that supports ``^``, ``|`` and ``&`` (plain
    ints, other ``Vector`` objects, ...).  All operations return new
    vectors; the receiver is never modified.
    """

    ZERO = 0
    # Optional fixed length enforced by make(); None disables the check.
    WIDTH = None

    @classmethod
    def make(cls, lst):
        """Build an instance of ``cls`` from an iterable, checking WIDTH if set."""
        lst = list(lst)
        if cls.WIDTH is not None:
            assert len(lst) == cls.WIDTH
        return cls(lst)

    def split(self, n=2):
        """Cut the vector into ``n`` consecutive chunks of equal length.

        Returns a plain ``Vector`` whose items are the chunks, each built
        with ``make`` of the receiver's class.
        """
        assert len(self) % n == 0
        w = len(self) // n
        return Vector(self.make(self[i:i+w]) for i in range(0, len(self), w))

    def rol(self, n=1):
        """Return a copy rotated left by ``n`` positions (negative rotates right)."""
        if not self:
            # Nothing to rotate; also avoids a modulo-by-zero below.
            return self.make([])
        n %= len(self)
        return self.make(self[n:] + self[:n])

    def ror(self, n=1):
        """Return a copy rotated right by ``n`` positions."""
        return self.rol(-n)

    def __repr__(self):
        return "<Vector len=%d list=%r>" % (len(self), list(self))

    def flatten(self):
        """Concatenate the items of this vector into one sequence.

        When the items are ``Vector`` objects the result is a plain
        ``Vector`` holding all their elements in order.  Otherwise the items
        are folded with ``+`` (so a vector of lists yields one list).  An
        empty vector flattens to an empty ``Vector``.
        """
        if not self:
            return Vector()
        if isinstance(self[0], Vector):
            return Vector(item for part in self for item in part)
        # reduce is not a builtin in Python 3; import it where it is needed.
        from functools import reduce
        return reduce(operator.add, list(self))

    def map(self, f, with_coord=False):
        """Apply ``f`` to every element; with ``with_coord`` call ``f(index, value)``."""
        if with_coord:
            return self.make(f(i, v) for i, v in enumerate(self))
        return self.make(f(v) for v in self)

    def _zip_with(self, other, op):
        """Combine two equal-length vectors element by element with ``op``."""
        assert isinstance(other, Vector)
        assert len(self) == len(other)
        return self.make(op(a, b) for a, b in zip(self, other))

    def __xor__(self, other):
        return self._zip_with(other, operator.xor)

    def __or__(self, other):
        return self._zip_with(other, operator.or_)

    def __and__(self, other):
        return self._zip_with(other, operator.and_)

    def set(self, x, val):
        """Return a copy in which position ``x`` holds ``val``."""
        return self.make(v if i != x else val for i, v in enumerate(self))
Operations on matrices (Rect)¶
class Rect(object):
    """A mutable ``h`` x ``w`` grid stored as a list of row lists (``self.lst``).

    Cells are addressed as ``rect[y, x]`` (row first, then column).
    """

    def __init__(self, vec, h=None, w=None):
        """Lay out the flat sequence ``vec`` row by row.

        At least one of ``h`` (rows) and ``w`` (columns) must be given; the
        other is derived from ``len(vec)``.  When both are given they are
        checked against ``len(vec)`` instead of one silently overriding the
        other.
        """
        assert h or w
        if not w:
            w = len(vec) // h
        elif not h:
            h = len(vec) // w
        assert w * h == len(vec)
        self.w, self.h = w, h
        self.lst = [list(vec[i:i+w]) for i in range(0, len(vec), w)]

    @classmethod
    def from_rect(cls, rect):
        """Wrap an existing list of rows without copying it."""
        self = object.__new__(cls)
        self.lst = rect
        self.h = len(rect)
        self.w = len(rect[0])
        return self

    def __getitem__(self, pos):
        y, x = pos
        return self.lst[y][x]

    def __setitem__(self, pos, val):
        y, x = pos
        self.lst[y][x] = val

    def row(self, i):
        """Return row ``i`` as a new ``Vector``."""
        return Vector(self.lst[i])

    def col(self, i):
        """Return column ``i`` as a new ``Vector``."""
        return Vector(self.lst[y][i] for y in range(self.h))

    def set_row(self, y, vec):
        """Overwrite row ``y`` in place with the first ``w`` items of ``vec``."""
        for x in range(self.w):
            self.lst[y][x] = vec[x]
        return self

    def set_col(self, x, vec):
        """Overwrite column ``x`` in place with the first ``h`` items of ``vec``."""
        for y in range(self.h):
            self.lst[y][x] = vec[y]
        return self

    def apply(self, f, with_coord=False):
        """Replace every cell in place by ``f(value)`` or ``f(y, x, value)``."""
        for y in range(self.h):
            if with_coord:
                self.lst[y] = [f(y, x, v) for x, v in enumerate(self.lst[y])]
            else:
                self.lst[y] = [f(v) for v in self.lst[y]]
        return self

    def apply_row(self, x, func):
        """Replace row ``x`` in place by ``func(row)``."""
        return self.set_row(x, func(self.row(x)))

    def apply_col(self, x, func):
        """Replace column ``x`` in place by ``func(column)``."""
        return self.set_col(x, func(self.col(x)))

    def flatten(self):
        """Return all cells in row-major order as a ``Vector``."""
        lst = []
        for v in self.lst:
            lst += v
        return Vector(lst)

    def zipwith(self, f, other):
        """Return a new ``Rect`` whose cells are ``f(mine, theirs)``."""
        assert isinstance(other, Rect)
        assert self.h == other.h
        assert self.w == other.w
        return Rect(
            [f(a, b) for a, b in zip(self.flatten(), other.flatten())],
            h=self.h, w=self.w
        )

    def transpose(self):
        """Return a new ``Rect`` with rows and columns swapped."""
        rect = [[self.lst[y][x] for y in range(self.h)] for x in range(self.w)]
        return Rect.from_rect(rect=rect)

    def __repr__(self):
        return "<Rect %dx%d>" % (self.h, self.w)
Bit-slicing implementation of Sbox¶
def Not(x):
    """Return the complement of the bit ``x``, computed as ``1 ^ x``."""
    flipped = 1 ^ x
    return flipped
def GF_SQ_2(A):
    """Return the pair ``(A[1], A[0])``, i.e. the two components swapped."""
    first, second = A[0], A[1]
    return second, first
def GF_SCLW_2(A):
    """Return the pair ``(A[1], A[1] ^ A[0])``."""
    upper = A[1]
    mixed = upper ^ A[0]
    return upper, mixed
def GF_SCLW2_2(A):
    """Return the pair ``(A[1] ^ A[0], A[0])``."""
    lower = A[0]
    mixed = A[1] ^ lower
    return mixed, lower
def GF_MULS_2(A, ab, B, cd):
    """Combine two bit pairs ``A`` and ``B`` using three AND gates.

    ``ab`` and ``cd`` are extra single bits supplied by the caller (the
    callers in this file pass ``A[1] ^ A[0]`` and ``B[1] ^ B[0]``).
    Returns a 2-tuple ``(q, p)``.
    """
    shared = ab & cd
    high = (A[1] & B[1]) ^ shared
    low = (A[0] & B[0]) ^ shared
    return low, high
def GF_MULS_SCL_2(A, ab, B, cd):
    """Variant of the bit-pair combination that shares the low product.

    ``ab`` and ``cd`` are extra single bits supplied by the caller.
    Returns a 2-tuple ``(q, p)``.
    """
    low_product = A[0] & B[0]
    second = (ab & cd) ^ low_product
    first = (A[1] & B[1]) ^ low_product
    return first, second
def XOR_LIST(a, b):
    """Return a list with the element-wise XOR of ``a`` and ``b``.

    Pairs are formed with ``zip``, so the result is as long as the shorter
    input.
    """
    result = []
    for left, right in zip(a, b):
        result.append(left ^ right)
    return result
def NotOr(a, b):
    """Return NOR of two bits, built as ``Not(a) & Not(b)``."""
    # return Not(a | b)
    not_a = Not(a)
    not_b = Not(b)
    return not_a & not_b
def GF_INV_4(A):
    """Transform a 4-bit value into a 4-tuple of bits.

    ``A`` is split into an upper pair ``A[2:4]`` and a lower pair
    ``A[0:2]``; the result is the concatenation ``q + p`` of two
    ``GF_MULS_2`` outputs.
    NOTE(review): presumably inversion in GF(2^4) over a GF(2^2) tower
    (Canright-style compact S-box) — confirm against the reference.
    """
    a = A[2:4]  # upper bit pair
    b = A[0:2]  # lower bit pair
    # Sums of each pair's two bits, passed to GF_MULS_2 as its extra inputs.
    sa = a[1] ^ a[0]
    sb = b[1] ^ b[0]
    ab = GF_MULS_2(a, sa, b, sb)
    ab2 = GF_SQ_2(XOR_LIST(a, b))
    ab2N = GF_SCLW2_2(ab2)
    d = GF_SQ_2(XOR_LIST(ab, ab2N))
    # NOTE(review): `c` is computed here but never read afterwards in this
    # function. It looks like an alternative way to derive `d`; on circuit
    # nodes it may still create gates — confirm before removing.
    c = [
        NotOr(sa, sb) ^ (Not(a[0] & b[0])),
        NotOr(a[1], b[1]) ^ (Not(sa & sb)),
    ]
    sd = d[1] ^ d[0]
    p = GF_MULS_2(d, sd, b, sb)
    q = GF_MULS_2(d, sd, a, sa)
    # GF_MULS_2 returns tuples, so this is a 4-tuple: q in positions 0-1,
    # p in positions 2-3.
    return q + p
def GF_SQ_SCL_4(A):
    """Transform a 4-bit value using the 2-bit square and scale helpers.

    ``A`` is split into an upper pair ``A[2:4]`` and a lower pair
    ``A[0:2]``.  Returns a 4-tuple: ``GF_SCLW_2(GF_SQ_2(lower))`` followed by
    ``GF_SQ_2(upper ^ lower)``.
    """
    a = A[2:4]
    b = A[0:2]
    # Slices of lists/tuples/Vectors are plain sequences without a ``^``
    # operator, so the pairwise XOR has to go through XOR_LIST.
    ab2 = GF_SQ_2(XOR_LIST(a, b))
    b2 = GF_SQ_2(b)
    b2N2 = GF_SCLW_2(b2)
    return b2N2 + ab2
def GF_MULS_4(A, a, Al, Ah, aa, B, b, Bl, Bh, bb):
    """Combine two 4-bit values ``A`` and ``B`` from three 2-bit products.

    The remaining arguments are precomputed bit sums supplied by the
    caller.  Returns a list of four bits: the low half followed by the high
    half.
    """
    high_product = GF_MULS_2(A[2:4], Ah, B[2:4], Bh)
    low_product = GF_MULS_2(A[0:2], Al, B[0:2], Bl)
    cross = GF_MULS_SCL_2(a, aa, b, bb)
    low_half = XOR_LIST(low_product, cross)
    high_half = XOR_LIST(high_product, cross)
    return low_half + high_half
def GF_INV_8(A):
    """Transform an 8-bit value into a list of 8 bits.

    ``A`` is split into an upper nibble ``A[4:8]`` and a lower nibble
    ``A[0:4]``.  A 4-bit value ``c`` is derived from both, passed through
    ``GF_INV_4``, and the result is combined with each nibble using
    ``GF_MULS_4``.
    NOTE(review): presumably inversion in GF(2^8) in a tower-field basis
    (Canright-style compact S-box) — confirm against the reference.
    """
    a = A[4:8]  # upper nibble
    b = A[0:4]  # lower nibble
    # Pairwise sums of each nibble's two halves.
    sa = XOR_LIST(a[2:4], a[0:2])
    sb = XOR_LIST(b[2:4], b[0:2])
    # Single-bit sums reused by the GF_MULS_4 calls below.
    al = a[1] ^ a[0]
    ah = a[3] ^ a[2]
    aa = sa[1] ^ sa[0]
    bl = b[1] ^ b[0]
    bh = b[3] ^ b[2]
    bb = sb[1] ^ sb[0]
    # Products shared between several bits of c.
    c1 = (ah & bh)
    c2 = (sa[0] & sb[0])
    c3 = (aa & bb)
    c = [
        (NotOr(a[0] , b[0] ) ^ ((al & bl))) ^ ((sa[1] & sb[1])) ^ Not(c2), #0
        (NotOr(al , bl ) ^ (Not(a[1] & b[1]))) ^ c2 ^ c3 , #1
        (NotOr(sa[1], sb[1]) ^ (Not(a[2] & b[2]))) ^ c1 ^ c2 , #2
        (NotOr(sa[0], sb[0]) ^ (Not(a[3] & b[3]))) ^ c1 ^ c3 , #3
    ]
    d = GF_INV_4(c)
    sd = XOR_LIST(d[2:4], d[0:2])
    dl = d[1] ^ d[0]
    dh = d[3] ^ d[2]
    dd = sd[1] ^ sd[0]
    p = GF_MULS_4(d, sd, dl, dh, dd, b, sb, bl, bh, bb)
    q = GF_MULS_4(d, sd, dl, dh, dd, a, sa, al, ah, aa)
    # q occupies positions 0-3 of the result, p positions 4-7.
    return q + p
def MUX21I(A, B, s):
    """Inverting 2-to-1 selector: complement of ``A`` if ``s`` is truthy, else of ``B``.

    ``s`` is tested with an ordinary Python conditional, so the choice is
    made when the function runs rather than inside the computed value.
    """
    # return ((~A & s) ^ (~B & ~s)
    if s:
        chosen = A
    else:
        chosen = B
    return Not(chosen)
def SELECT_NOT_8(A, B, s):
    """Apply ``MUX21I`` to each of the first 8 positions of ``A`` and ``B``.

    Returns a list of 8 values.
    """
    return [MUX21I(A[pos], B[pos], s) for pos in range(8)]
def Sbox(A, encrypt):
    """Compute an 8-bit substitution as a fixed sequence of bit operations.

    ``A`` is indexed ``A[0]`` .. ``A[7]``.  ``encrypt`` is tested for
    truthiness inside ``SELECT_NOT_8``/``MUX21I``: a truthy value selects
    the ``B`` input mapping and the ``D`` output mapping, a falsy value
    selects ``Y`` and ``X``.  Both mappings are always computed.  Returns a
    list of 8 bits.
    NOTE(review): presumably the truthy path is the forward AES S-box and
    the falsy path its inverse (Canright-style compact S-box) — confirm.
    """
    # Shared intermediate sums for the two input mappings.
    R1 = A[7] ^ A[5]
    R2 = A[7] ^ Not(A[4])
    R3 = A[6] ^ A[0]
    R4 = A[5] ^ Not(R3)
    R5 = A[4] ^ R4
    R6 = A[3] ^ A[0]
    R7 = A[2] ^ R1
    R8 = A[1] ^ R3
    R9 = A[3] ^ R8
    # Input mapping used when `encrypt` is truthy.
    B = [None] * 8
    B[7] = R7 ^ Not(R8)
    B[6] = R5
    B[5] = A[1] ^ R4
    B[4] = R1 ^ Not(R3)
    B[3] = A[1]^ R2 ^ R6
    B[2] = Not( A[0])
    B[1] = R4
    B[0] = A[2] ^ Not(R9)
    # Input mapping used when `encrypt` is falsy.
    Y = [None] * 8
    Y[7] = R2
    Y[6] = A[4] ^ R8
    Y[5] = A[6] ^ A[4]
    Y[4] = R9
    Y[3] = A[6] ^ Not(R2)
    Y[2] = R7
    Y[1] = A[4] ^ R6
    Y[0] = A[1] ^ R5
    # SELECT_NOT_8 complements whichever mapping it picks.
    Z = SELECT_NOT_8(B, Y, encrypt)
    C = GF_INV_8(Z)
    # Shared intermediate sums for the two output mappings.
    T1 = C[7] ^ C[3]
    T2 = C[6] ^ C[4]
    T3 = C[6] ^ C[0]
    T4 = C[5] ^ Not(C[3])
    T5 = C[5] ^ Not(T1)
    T6 = C[5] ^ Not(C[1])
    T7 = C[4] ^ Not(T6)
    T8 = C[2] ^ T4
    T9 = C[1] ^ T2
    T10 = T3 ^ T5
    # Output mapping used when `encrypt` is truthy.
    D = [None] * 8
    D[7] = T4
    D[6] = T1
    D[5] = T3
    D[4] = T5
    D[3] = T2 ^ T5
    D[2] = T3 ^ T8
    D[1] = T7
    D[0] = T9
    # Output mapping used when `encrypt` is falsy.
    X = [None] * 8
    X[7] = C[4] ^ Not(C[1])
    X[6] = C[1] ^ T10
    X[5] = C[2] ^ T10
    X[4] = C[6] ^ Not(C[1])
    X[3] = T8 ^ T9
    X[2] = C[7] ^ Not(T7)
    X[1] = T6
    X[0] = Not(C[2])
    return SELECT_NOT_8(D, X, encrypt)
def bitSbox(A, inverse=False):
    """Run ``Sbox`` on ``A`` with the bit order reversed on the way in and out.

    ``inverse`` is converted to ``Sbox``'s ``encrypt`` flag as
    ``1 - inverse``.
    """
    reversed_input = A[::-1]
    reversed_output = Sbox(reversed_input, encrypt=1-inverse)
    return reversed_output[::-1]
Bit-slicing implementations of ShiftRow and MixColumn¶
def ShiftRow(row, nr, inverse=False):
    """Rotate ``row`` left by ``nr`` positions modulo 4 (right if ``inverse``).

    The result is built by concatenating two slices of ``row``.
    """
    shift = (-nr if inverse else nr) % 4
    head, tail = row[:shift], row[shift:]
    return tail + head
def MixColumn(col, inverse=False):
    """Mix one column of 4 bytes (each 8 bits) using an XOR table.

    Output bit ``8 * yi + yj`` is the XOR, starting from 0, of the input
    bits listed for it in ``MC_TABLE`` (or ``MCi_TABLE`` when ``inverse`` is
    truthy).  Returns a list of 4 lists of 8 bits.
    """
    if inverse:
        table = MCi_TABLE
    else:
        table = MC_TABLE
    mixed = []
    for yi in range(4):
        out_byte = []
        for yj in range(8):
            acc = 0
            for x in table[yi * 8 + yj]:
                # Flat input index -> (byte, bit within byte).
                xi, xj = divmod(x, 8)
                acc ^= col[xi][xj]
            out_byte.append(acc)
        mixed.append(out_byte)
    return mixed
# XOR tables consumed by MixColumn: entry y is the set of flat input-bit
# indices x that are XORed together to form flat output bit y. MixColumn maps
# a flat index to (byte, bit) with divmod(index, 8). MC_TABLE is used for the
# forward direction and MCi_TABLE when MixColumn is called with inverse=True.
# y -> set of x indices to xor
MC_TABLE = [{1, 8, 9, 16, 24}, {2, 9, 10, 17, 25}, {3, 10, 11, 18, 26}, {0, 4, 8, 11, 12, 19, 27}, {0, 5, 8, 12, 13, 20, 28}, {6, 13, 14, 21, 29}, {0, 7, 8, 14, 15, 22, 30}, {0, 8, 15, 23, 31}, {0, 9, 16, 17, 24}, {1, 10, 17, 18, 25}, {2, 11, 18, 19, 26}, {3, 8, 12, 16, 19, 20, 27}, {4, 8, 13, 16, 20, 21, 28}, {5, 14, 21, 22, 29}, {6, 8, 15, 16, 22, 23, 30}, {7, 8, 16, 23, 31}, {0, 8, 17, 24, 25}, {1, 9, 18, 25, 26}, {2, 10, 19, 26, 27}, {3, 11, 16, 20, 24, 27, 28}, {4, 12, 16, 21, 24, 28, 29}, {5, 13, 22, 29, 30}, {6, 14, 16, 23, 24, 30, 31}, {7, 15, 16, 24, 31}, {0, 1, 8, 16, 25}, {1, 2, 9, 17, 26}, {2, 3, 10, 18, 27}, {0, 3, 4, 11, 19, 24, 28}, {0, 4, 5, 12, 20, 24, 29}, {5, 6, 13, 21, 30}, {0, 6, 7, 14, 22, 24, 31}, {0, 7, 15, 23, 24}]
MCi_TABLE = [{1, 2, 3, 8, 9, 11, 16, 18, 19, 24, 27}, {0, 2, 3, 4, 8, 9, 10, 12, 16, 17, 19, 20, 24, 25, 28}, {1, 3, 4, 5, 8, 9, 10, 11, 13, 17, 18, 20, 21, 24, 25, 26, 29}, {2, 4, 5, 6, 8, 9, 10, 11, 12, 14, 16, 18, 19, 21, 22, 25, 26, 27, 30}, {1, 2, 5, 6, 7, 10, 12, 13, 15, 16, 17, 18, 20, 22, 23, 24, 26, 28, 31}, {1, 6, 7, 8, 9, 13, 14, 17, 21, 23, 24, 25, 29}, {2, 7, 8, 9, 10, 14, 15, 16, 18, 22, 25, 26, 30}, {0, 1, 2, 8, 10, 15, 17, 18, 23, 26, 31}, {0, 3, 9, 10, 11, 16, 17, 19, 24, 26, 27}, {0, 1, 4, 8, 10, 11, 12, 16, 17, 18, 20, 24, 25, 27, 28}, {0, 1, 2, 5, 9, 11, 12, 13, 16, 17, 18, 19, 21, 25, 26, 28, 29}, {1, 2, 3, 6, 10, 12, 13, 14, 16, 17, 18, 19, 20, 22, 24, 26, 27, 29, 30}, {0, 2, 4, 7, 9, 10, 13, 14, 15, 18, 20, 21, 23, 24, 25, 26, 28, 30, 31}, {0, 1, 5, 9, 14, 15, 16, 17, 21, 22, 25, 29, 31}, {1, 2, 6, 10, 15, 16, 17, 18, 22, 23, 24, 26, 30}, {2, 7, 8, 9, 10, 16, 18, 23, 25, 26, 31}, {0, 2, 3, 8, 11, 17, 18, 19, 24, 25, 27}, {0, 1, 3, 4, 8, 9, 12, 16, 18, 19, 20, 24, 25, 26, 28}, {1, 2, 4, 5, 8, 9, 10, 13, 17, 19, 20, 21, 24, 25, 26, 27, 29}, {0, 2, 3, 5, 6, 9, 10, 11, 14, 18, 20, 21, 22, 24, 25, 26, 27, 28, 30}, {0, 1, 2, 4, 6, 7, 8, 10, 12, 15, 17, 18, 21, 22, 23, 26, 28, 29, 31}, {1, 5, 7, 8, 9, 13, 17, 22, 23, 24, 25, 29, 30}, {0, 2, 6, 9, 10, 14, 18, 23, 24, 25, 26, 30, 31}, {1, 2, 7, 10, 15, 16, 17, 18, 24, 26, 31}, {0, 1, 3, 8, 10, 11, 16, 19, 25, 26, 27}, {0, 1, 2, 4, 8, 9, 11, 12, 16, 17, 20, 24, 26, 27, 28}, {0, 1, 2, 3, 5, 9, 10, 12, 13, 16, 17, 18, 21, 25, 27, 28, 29}, {0, 1, 2, 3, 4, 6, 8, 10, 11, 13, 14, 17, 18, 19, 22, 26, 28, 29, 30}, {2, 4, 5, 7, 8, 9, 10, 12, 14, 15, 16, 18, 20, 23, 25, 26, 29, 30, 31}, {0, 1, 5, 6, 9, 13, 15, 16, 17, 21, 25, 30, 31}, {0, 1, 2, 6, 7, 8, 10, 14, 17, 18, 22, 26, 31}, {0, 2, 7, 9, 10, 15, 18, 23, 24, 25, 26}]
Bit-slicing implementation of Key Schedule¶
# Byte constants for the key schedule. BitRcon is derived from this list and
# ks_core indexes it by `iteration`. KS_round calls ks_core with rno + 1, so
# entry 0 (0x8d) is skipped when round numbers start at 0.
Rcon = [0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97,
0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72,
0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66,
0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04,
0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d,
0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3,
0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61,
0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a,
0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40,
0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc,
0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5,
0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a,
0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d,
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c,
0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35,
0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4,
0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc,
0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08,
0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d,
0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2,
0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74,
0xe8, 0xcb ]
def tobin(x, n):
    """Return the binary digits of ``x`` as a tuple of ints, most significant first.

    The result is left-padded with zeros to at least ``n`` digits; values
    that need more than ``n`` bits are not truncated.

    Raises:
        ValueError: if ``x`` is negative.
    """
    if x < 0:
        raise ValueError("tobin() requires a non-negative integer")
    # format(x, "b") yields the digits without a "0b" prefix, so nothing has
    # to be stripped (str.lstrip("0b") removes a character set, not a prefix).
    return tuple(int(digit) for digit in format(x, "b").rjust(n, "0"))
def c8(c):
    """Return the integer ``c`` as a ``Vector`` of bits padded to 8 digits."""
    bits = tobin(c, 8)
    return Vector(bits)
# Bit-vector form of every Rcon entry, in the same order.
BitRcon = [c8(value) for value in Rcon]
def ks_rotate(word):
    """Return ``word`` with its first element moved to the end."""
    first, rest = word[:1], word[1:]
    return rest + first
def ks_core(word, iteration):
    """Key-schedule core applied to one word (a ``Vector`` of bit vectors).

    Rotates the word left by one element, passes every element through
    ``bitSbox``, then XORs ``BitRcon[iteration]`` into element 0.
    """
    rotated = word.rol(1)
    substituted = rotated.map(lambda byte: Vector(bitSbox(byte)))
    first = substituted[0] ^ BitRcon[iteration]
    return substituted.set(0, first)
def KS_round(kstate, rno):
    """Advance the key state by one round, in place, and return it.

    Column 0 is XORed with ``ks_core`` of column 3 (using constant index
    ``rno + 1``); each following column is XORed with the freshly updated
    column before it.
    """
    feed = ks_core(kstate.col(3), rno+1)
    for x in range(4):
        if x:
            # Chain on the column that was just updated.
            feed = kstate.col(x - 1)
        kstate.apply_col(x, lambda column, feed=feed: column ^ feed)
    return kstate
Bit-slicing AES¶
def BitAES(plaintext, key, rounds=10):
    """Run the bit-level AES rounds over ``plaintext`` and ``key``.

    Both inputs are split into 16 equal chunks and arranged as transposed
    4x4 states.  Each round applies AK, SB and SR, then MC on every round
    except the last, and then advances the key state.  A final AK follows
    the loop.

    Returns ``(bits, kbits)``: the flattened final state and the flattened
    final key state, each as a flat list.
    """
    state = Rect(Vector(plaintext).split(16), w=4, h=4).transpose()
    kstate = Rect(Vector(key).split(16), w=4, h=4).transpose()

    for rno in range(rounds):
        state = SR(SB(AK(state, kstate)))
        if rno < rounds-1:
            state = MC(state)
        kstate = KS(kstate, rno)
    state = AK(state, kstate)

    # Undo the initial transposition before serialising.
    state = state.transpose()
    kstate = kstate.transpose()

    bits = []
    for byte in state.flatten():
        bits = bits + list(byte)
    kbits = []
    for byte in kstate.flatten():
        kbits = kbits + list(byte)
    return bits, kbits
def AK(state, kstate):
    """Return a new state whose cells are ``state`` XOR ``kstate``, cell by cell."""
    def xor_cells(state_cell, key_cell):
        return state_cell ^ key_cell
    return state.zipwith(xor_cells, kstate)
def SB(state, inverse=False):
    """Replace every cell of ``state`` in place by its ``bitSbox`` image.

    Returns the same ``state`` object.
    """
    def substitute(cell):
        return Vector(bitSbox(cell, inverse=inverse))
    return state.apply(substitute)
def SR(state, inverse=False):
    """Rotate row ``y`` of ``state`` by ``y`` positions in place, for y = 0..3.

    Returns the same ``state`` object.
    """
    for y in range(4):
        state.apply_row(
            y,
            lambda row, shift=y: ShiftRow(row, shift, inverse=inverse),
        )
    return state
def MC(state, inverse=False):
    """Apply ``MixColumn`` to each of the 4 columns of ``state`` in place.

    ``inverse`` is forwarded to ``MixColumn`` so that ``MC(state,
    inverse=True)`` uses the inverse table instead of silently applying
    the forward transformation.  Returns the same ``state`` object.
    """
    for x in range(4):
        state.apply_col(
            x,
            lambda v: [Vector(b) for b in MixColumn(v, inverse=inverse)],
        )
    return state
def KS(kstate, rno):
    """Advance the key state by one round; thin wrapper around ``KS_round``."""
    next_kstate = KS_round(kstate, rno)
    return next_kstate
Build a circuit for bit-slicing AES¶
# Build a circuit with 128 plaintext inputs and 128 key inputs, and run
# BitAES on the objects the circuit returns for them.
from circkit.boolean import BooleanCircuit
C = BooleanCircuit()
# NOTE(review): presumably `format` names the inputs m0..m127 / k0..k127 —
# confirm against the circkit documentation.
pt = C.add_inputs(n=128, format="m%d")
k = C.add_inputs(n=128, format="k%d")
# BitAES returns (state bits, key-state bits); k10 is not used afterwards.
ct, k10 = BitAES(pt, k, rounds=10)
C.add_output(ct)
Then we can evaluate the circuit to test its correctness.
def str2bin(s):
    """Return the bits of each character's code point, at least 8 bits per character.

    Characters are processed in order; each code point is written in binary
    and zero-padded on the left to 8 digits.
    """
    bits = []
    for ch in s:
        bits.extend(int(digit) for digit in format(ord(ch), "b").zfill(8))
    return bits
def hex2bin(s):
    """Decode the hex string ``s`` and return its bits, 8 per byte, MSB first."""
    bits = []
    for byte in bytes.fromhex(s):
        bits.extend(int(digit) for digit in format(byte, "08b"))
    return bits
def bin2hex(s):
    """Return the bit sequence ``s`` (MSB first) as a lowercase hex string.

    The length of ``s`` must be a multiple of 8; the result is zero-padded
    to ``len(s) // 4`` hex digits.
    """
    assert len(s) % 8 == 0
    digits = "".join(str(bit) for bit in s)
    value = int(digits, 2)
    return format(value, "x").zfill(len(s) // 4)
# Test vector: plaintext MSG under key KEY.
# Expected ciphertext: 69c4e0d86a7b0430d8cdb78070b4c55a
# See the AES documentation of NIST: https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.197.pdf
MSG = "00112233445566778899aabbccddeeff"
KEY = "000102030405060708090a0b0c0d0e0f"
# Plaintext bits come first, then key bits, in the order the inputs were
# added to the circuit.
inp = hex2bin(MSG) + hex2bin(KEY)
out = C.evaluate(inp)
out = bin2hex(out)
print(f"Ciphertext: {out}")
Ciphertext: 69c4e0d86a7b0430d8cdb78070b4c55a
For a white-box implementation, we can embed the key in the circuit as 128 constant bits instead of 128 input nodes.
# Rebuild the circuit with the key bits added as constants, so that only the
# 128 plaintext bits remain as inputs.
C = BooleanCircuit()
pt = C.add_inputs(n=128, format="m%d")
k = [C.add_const(b) for b in hex2bin(KEY)]
# k10 (the key-state bits returned by BitAES) is not used afterwards.
ct, k10 = BitAES(pt, k, rounds=10)
C.add_output(ct)
# Only the plaintext is supplied at evaluation time.
inp = hex2bin(MSG)
out = C.evaluate(inp)
out = bin2hex(out)
print(f"Ciphertext: {out}")
Ciphertext: 69c4e0d86a7b0430d8cdb78070b4c55a