Writeup for bfv and PSN in 2021qwb Final

题目附件:bfv PSN

bfv

题目实现了一个bfv同态加密算法,并且提供了encryption oracle和decryption oracle(option 2)。

大部分同态加密算法都不是CCA安全的,攻击者可以通过decryption oracle来恢复出私钥。

paper: Danger of using fully homomorphic encryption: A look at Microsoft SEAL

bfv加解密过程如下:

image-20210713105002449

攻击者可以往decryption oracle发送

$$ \begin{aligned} c[0] &= 0 \newline c[1] &= \delta = \lfloor q/t \rfloor \end{aligned} $$

解密的结果为私钥s

$$ \begin{aligned} tmp &= [c[0] + c[1]*s]_q = [ \delta * s]_q \newline m &= [round( \delta * s * t/q)]_t = s \end{aligned} $$

image-20210713105324699


但是这一题中的mutual函数并不会返回所有的解密信息,只会判断解密结果的某一位等不等于0。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
def mutual(k, c, s):
    """Decrypt ``c`` with secret key ``s`` and print whether one coefficient is zero.

    ``c`` is indexed as ``c[0]`` and ``c[1]``; ``k`` selects the plaintext
    coefficient to test (clamped from above by ``d``). Nothing is returned:
    the only output is a printed ``True`` or ``False``.
    """
    scaled = t * Roundq(c[0] + c[1] * s)
    # Rescale every coefficient by 1/q and round to the nearest integer.
    rescaled = [round(coeff / q) for coeff in scaled.list()]
    plain = Roundt(R(rescaled))  # decryption is over here
    # Only a single yes/no answer about the plaintext is revealed.
    print(bool(plain[min(k, d)] == 0))

私钥s一共1024位,满足正态分布,取值范围大概率在-7~7之间。

单纯地使用上述方法,只能得到私钥$s_i = 0$的位数,并不能恢复出s的所有位数。

可以灵活地构造一下,发送 $$ \begin{aligned} c[0] &= -guess * \delta \newline c[1] &= \delta \end{aligned} $$ 给decryption oracle,解密结果为

$$ \begin{aligned} tmp &= [c[0] + c[1] * s]_q = [(-\text{guess} + s) * \delta ]_q \newline m &= [round( \delta * (-\text{guess}+s) * t/q)]_t = -\text{guess} + s \end{aligned} $$

guess依次取-7~7,即可得到满足$s_i=\text{guess}$的所有位数,即私钥s的每一位。

这题主要是Q7师傅的思路,tql~


解题流程:

  1. (option 2)获取admin的id_num密文
  2. (option 4)注册用户,直至用户数达到1024
  3. (option 2)发送$c_0 = -\text{guess} * \delta, c_1 = \delta$来获取所有$s_i == \text{guess}$的下标,从而恢复私钥s
  4. (本地)用私钥s对admin的id_num密文进行解密
  5. (option 1)加admin好友
  6. (option 3)拿flag
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
from random import randint, getrandbits
import sys

from pwn import *


# Scheme parameters shared by every helper below.
# NOTE: this is a Sage script, so `^` is exponentiation after preparsing.
q = 2 ^ 54  # coefficient modulus used by Roundq
t = 83  # plaintext modulus used by Roundt
T = 3  # default base for baseT and digit range for sample3
d = 1024  # number of coefficients per ring element; degree of f
delta = int(q / t)  # scaling factor applied to the message in encrypt
sigma = 2  # parameter handed to the discrete Gaussian sampler
# R is the quotient ring ZZ[x] / (x^d + 1); X is its generator.
P.<x> = PolynomialRing(ZZ)
f = x ^ d + 1
R.<X> = P.quotient(f)
# D() yields one integer sample per call.
D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma)


def sample1():
    """Return an element of R whose d coefficients are each drawn from D."""
    coeffs = []
    for _ in range(d):
        coeffs.append(D())
    return R(coeffs)


def sample2():
    """Return an element of R with d coefficients drawn by randint from [0, q - 1]."""
    coeffs = []
    for _ in range(d):
        coeffs.append(randint(0, q - 1))
    return R(coeffs)

def sample3(x):
    """Return a plain list of ``x`` integers, each drawn by randint from [0, T - 1]."""
    digits = []
    for _ in range(x):
        digits.append(randint(0, T - 1))
    return digits


def Roundq(a):
    """Reduce each coefficient of ``a`` modulo q into a centered range.

    Residues greater than q/2 are shifted down by q, so every resulting
    coefficient lies in (-q/2, q/2]. The result is rebuilt as an element of R.
    """
    centered = []
    for coeff in a.list():
        residue = coeff % q
        centered.append(residue - q if residue > (q / 2) else residue)
    return R(centered)


def Roundt(a):
    """Reduce each coefficient of ``a`` modulo t into a centered range.

    Residues greater than t/2 are shifted down by t, so every resulting
    coefficient lies in (-t/2, t/2]. The result is rebuilt as an element of R.
    """
    centered = []
    for coeff in a.list():
        residue = coeff % t
        centered.append(residue - t if residue > (t / 2) else residue)
    return R(centered)


def keygen():
    """Generate a key pair.

    Returns:
        A tuple ``(s, pk)`` where ``s`` comes from sample1() and ``pk`` is the
        two-element list ``[Roundq(-(a * s + e)), a]``.
    """
    # Sampling order (secret, uniform mask, noise) matches the original.
    secret = sample1()
    mask = Roundq(sample2())
    noise = Roundq(sample1())
    masked_secret = Roundq(-(mask * secret + noise))
    return secret, [masked_secret, mask]


def encrypt(m):
    """Encrypt ``m`` and return the ciphertext as a 2-tuple of ring elements.

    The public key is read from the module-level name ``pk``.
    NOTE(review): no assignment to ``pk`` is visible in this script, so it
    must be bound (e.g. from keygen()) before calling - confirm.
    """
    # Three fresh samples, drawn in the same order as the original.
    u = sample1()
    noise0 = sample1()
    noise1 = sample1()
    c0 = Roundq(pk[0] * u + noise0 + delta * m)
    c1 = Roundq(pk[1] * u + noise1)
    return (c0, c1)


def baseT(n, b=T):
    """Return the base-``b`` digits of ``n``, most significant digit first.

    ``baseT(0)`` yields ``[0]``. The loop only stops once the quotient
    reaches zero.
    """
    digits = [n % b]
    n = n // b
    while n != 0:
        digits.append(n % b)
        n = n // b
    # Digits were collected least significant first.
    return digits[::-1]

def rev_baseT(l):
    """Fold a digit sequence into an integer using base 3.

    Each step adds the digit and then multiplies by 3, so the final value is
    three times the conventional base-3 reading of ``l``; an empty sequence
    gives 0.
    NOTE(review): confirm this trailing factor of 3 matches how the caller
    slices the decrypted digits.
    """
    total = 0
    for digit in l:
        total = (total + digit) * 3
    return total

def decrypt(c, s):
    """Decrypt ciphertext ``c`` (indexed as c[0], c[1]) with secret key ``s``.

    Returns:
        The coefficient list of the recovered plaintext after centered
        reduction modulo t.
    """
    scaled = t * Roundq(c[0] + c[1] * s)
    # Rescale every coefficient by 1/q and round to the nearest integer.
    rescaled = [round(coeff / q) for coeff in scaled.list()]
    return Roundt(R(rescaled)).list()


# 0. get admin id_num ciphertext
# 1. register users, len(users) = 1024
# 2. send c0 = -sk_i*delta, c1 = delta to get index of sk==sk_i
# 3. decrypt "admin" ct to get id_num
# 4. add friends admin
# 5. get flag

# Exploit driver: talks to the challenge service and follows steps 0-5.
conn = remote("127.0.0.1", 9999)
# conn = remote("106.15.177.94", 8001)

# With DEBUG on, the script expects the service to print the real key after
# "sk:\n" so the recovered key can be compared against it further down.
DEBUG = False
if DEBUG:
    context.log_level = 'debug'
    conn.recvuntil(b"sk:\n")
    real_sk = conn.recvline().strip().decode()
    log.info(f"sk: {real_sk}")

# Skip everything up to the end of the menu text.
conn.recvuntil(b"4.Regist")

# 0. get admin id_num ciphertext
conn.sendlineafter(b">", b"2")
conn.sendlineafter(b"recv ct?(Y/N)", b"Y")
# Each ciphertext half arrives as a bracketed, ", "-separated integer list.
# The first slice drops one leading character, the second drops two
# (presumably "[" and a separator + "[" - TODO confirm against the server).
admin_id_num_ct0 = list(map(int, conn.recvuntil(b"]").decode()[1:-1].split(", ")))
admin_id_num_ct1 = list(map(int, conn.recvuntil(b"]").decode()[2:-1].split(", ")))
# log.info(f"admin_id_num_ct0: {admin_id_num_ct0}\nadmin_id_num_ct1: {admin_id_num_ct1}\n")

# Decline further ciphertexts, then answer five c1/c2 prompts with zeros.
# NOTE(review): five presumably equals the number of pre-registered users
# (it matches the `1024-5` below) - confirm against the server.
conn.sendlineafter(b"continue?(Y/N)", b"N")
for _ in range(5):
    conn.sendlineafter(b"c1:", b"0")
    conn.sendlineafter(b"c2:", b"0")


# 1. register users, len(users) = 1024
for i in range(1024-5):
    conn.sendlineafter(b">", b"4")
    conn.sendlineafter(b"name:", str(i).encode())
log.info("Register over!")


# 2. send c0 = -sk_i*delta, c1 = delta to get index of sk==sk_i
delta=floor(q/t)

sk = [0]*1024
# NOTE(review): the next assignment is overwritten immediately; only the
# encoded string form of delta is ever sent.
ct1 = delta
ct1 = str(delta).encode()

# One full pass over all 1024 prompts per guessed coefficient value.
for sk_i in range(-7, 8): # [-7, 7]
    log.info(f"guess: {sk_i}")
    conn.sendlineafter(b">", b"2")
    conn.sendlineafter(b"recv ct?(Y/N)", b"N")
    # c0 is the constant -guess*delta repeated for every coefficient.
    ct0 = [-sk_i*delta] * 1024
    ct0 = " ".join(str(num) for num in ct0).encode()
    for index in range(1024):
        conn.sendlineafter(b"c1:", ct0)
        conn.sendlineafter(b"c2:", ct1)
        res = conn.recvline().strip().decode()
        # A "True" reply marks position `index` as equal to the current guess.
        if res == "True":
            sk[index] = sk_i
            print(f"{index:-4d}: {sk_i}")

log.info(f"sk: {sk}")
if DEBUG:
    # Compare against the key string captured at startup.
    print(str(sk) == real_sk)
    real_sk = list(map(int, real_sk[1:-1].split(", ")))

# 3. decrypt "admin" ct to get id_num
pt = decrypt([R(admin_id_num_ct0), R(admin_id_num_ct1)], R(sk))

# 4. add friends admin
# The digit count of id_num is not known, so prefixes of 17..20 decrypted
# coefficients are tried until the service stops answering "failed".
for i in range(18, 22):
    id_num = rev_baseT(pt[:i-1])
    log.info(f"id_num: {id_num} {int(id_num).bit_length()}\n")
    conn.sendlineafter(b">", b"1")
    conn.sendlineafter(b"name:", b"admin")
    conn.sendlineafter(b"id:", str(id_num).encode())
    res = conn.recvline().strip().decode()
    log.info(f"res: {res}")
    if res == "failed":
        continue
    break

# 5. get flag
conn.sendlineafter(b">", b"3")
conn.sendlineafter(b"name:", b"admin")
conn.sendlineafter(b"message:", b"give me the flag")
flag = conn.recvline()
log.info(f"flag: {flag}")

有概率会失败,因为私钥s的取值范围会超过-7~7(-7~7是权衡了交互次数和概率而取的一个范围)。

PSN

题目实现了一个Pseudo Stochastic Network(我也不知道这是个啥玩意儿)

问题出现在seed = flag + bias上

bias:32bytes

flag:24bytes,且flag格式为flag{.*},已知6bytes

加密时候用的密钥,偶数位key取的是seed的前16bytes,奇数位key取的是seed的后16bytes。

1
key = ((self.seed>>(128 if i%2==0 else 0)) + i)&((1<<128)-1)

bias是给了的,那么实际上偶数位的key我们未知的只有flag的前8bytes,而通过flag格式(以flag{开头),我们实际上能够知道这8bytes的前5bytes,未知的就只有末尾的3bytes,爆破一下就有了。

拿到偶数位的key之后,就可以识别出所有偶数位的cipher_list


cipher_list根据seed来选择的,是seed的3进制表示。

seed里面包含了flag的所有信息,只要能解出seed,就能拿到flag。

cipher_list又完全与seed相关,所以如果我们能识别出所有的cipher_list,就能算出seed,算出flag。

cipher_list => seed => flag

我们现在可以识别出所有偶数位的cipher_list,但无法识别出奇数位的cipher_list

但这并不影响,我们观察到题目其实没有限制pt的长度,并且超过len(cipher_list)的部分会被mod。

1
cipher = self.cipher_list[i%len(self.cipher_list)](key)

如果len(cipher_list)是奇数,那么发2*len(cipher_list)长的pt给服务端,偶数位就可以覆盖到cipher_list的所有位数。

len(cipher_list)等于seed的三进制表示的长度,而seed的位数又主要由bias决定,因此当bias的三进制位数(即$\lceil \log_3{bias} \rceil$)为奇数时,len(cipher_list)就是奇数。

前面一半长度等于len(cipher_list)的pt,可以覆盖到cipher_list下标为0, 2, 4, …, len(cipher_list)-1的部分

后面一半长度等于len(cipher_list)的pt,可以覆盖到cipher_list下标为1, 3, 5, …, len(cipher_list)-2的部分

这样,我们就可以识别出所有的cipher_list,进而算出flag。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
from pwn import *
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import random
import math

def xtime(a):
    """Shift ``a`` left by one bit; if bit 0x80 was set, XOR with 0x1B and keep the low byte."""
    doubled = a << 1
    if a & 0x80:
        return (doubled ^ 0x1B) & 0xFF
    return doubled

def text2matrix(text):
    """Convert a byte string into a 4x4 list-of-lists of byte values.

    The input is read as a big-endian integer and its low 128 bits are split
    into 16 bytes; consecutive groups of four bytes form the four inner lists.
    """
    value = int.from_bytes(text, "big")
    flat = [(value >> (8 * (15 - pos))) & 0xFF for pos in range(16)]
    return [flat[start:start + 4] for start in range(0, 16, 4)]

def matrix2text(matrix):
    """Pack a 4x4 matrix of byte values into a 16-byte big-endian string.

    Entry ``matrix[i][j]`` becomes byte number ``4 * i + j`` of the result.
    """
    packed = 0
    for pos in range(16):
        row, col = divmod(pos, 4)
        packed |= matrix[row][col] << (8 * (15 - pos))
    return packed.to_bytes(16, "big")

class backdoorAES:
    """AES-style block cipher on 16-byte blocks with a shortened round count.

    The key schedule produces 20 four-byte words (five round keys). Encryption
    applies an initial key addition, three full rounds, and a final round
    without the column-mixing step. The state is a 4x4 list of byte values as
    produced by text2matrix.
    """

    def __init__(self, master_key):
        """Create the cipher and derive round keys from ``master_key``."""
        self.change_key(master_key)

    def change_key(self, master_key):
        """(Re)initialise the substitution tables and expand ``master_key``.

        The result is stored in ``self.round_keys`` as a list of 20 four-byte
        rows.
        """
        # Forward substitution table.
        # NOTE(review): values look like the standard AES S-box - confirm.
        self.Sbox = (
            0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
            0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
            0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
            0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
            0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
            0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
            0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
            0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
            0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
            0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
            0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
            0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
            0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
            0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
            0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
            0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
        )
        # Table used by the inverse substitution step during decryption.
        self.InvSbox = (
            0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB,
            0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB,
            0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E,
            0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25,
            0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92,
            0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84,
            0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06,
            0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B,
            0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73,
            0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E,
            0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B,
            0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4,
            0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F,
            0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF,
            0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61,
            0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D,
        )
        # The first four rows of the schedule are the master key itself.
        self.round_keys = text2matrix(master_key)

        # Round constants, indexed by i // 4 in the expansion loop below.
        Rcon = (
            0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40,
            0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A,
            0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A,
            0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39,
        )
        # Extend the schedule from 4 rows to 20 rows (rows 4..19).
        for i in range(4, 4 * 5):
            self.round_keys.append([])
            if i % 4 == 0:
                # Every fourth row mixes in the previous row rotated by one
                # byte and substituted, plus a round constant on byte 0.
                byte = self.round_keys[i - 4][0] ^ self.Sbox[self.round_keys[i - 1][1]] ^ Rcon[i // 4]
                self.round_keys[i].append(byte)

                for j in range(1, 4):
                    byte = self.round_keys[i - 4][j] ^ self.Sbox[self.round_keys[i - 1][(j + 1) % 4]]
                    self.round_keys[i].append(byte)
            else:
                # Other rows are the XOR of the row four back and the previous row.
                for j in range(4):
                    byte = self.round_keys[i - 4][j] ^ self.round_keys[i - 1][j]
                    self.round_keys[i].append(byte)

    def encrypt(self, plaintext):
        """Encrypt one block given as bytes and return 16 bytes.

        The working state is kept on ``self.plain_state`` and mutated in place.
        """
        self.plain_state = text2matrix(plaintext)

        self.__add_round_key(self.plain_state, self.round_keys[:4])

        # Three full rounds using round keys 1..3.
        for i in range(1, 4):
            self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)])

        # Final round: no column mixing, last round key.
        self.__sub_bytes(self.plain_state)
        self.__shift_rows(self.plain_state)
        self.__add_round_key(self.plain_state, self.round_keys[16:])

        return matrix2text(self.plain_state)

    def decrypt(self, ciphertext):
        """Decrypt one block given as bytes and return 16 bytes.

        Applies the encryption steps in reverse order; the working state is
        kept on ``self.cipher_state`` and mutated in place.
        """
        self.cipher_state = text2matrix(ciphertext)

        # Undo the final round first.
        self.__add_round_key(self.cipher_state, self.round_keys[16:])
        self.__inv_shift_rows(self.cipher_state)
        self.__inv_sub_bytes(self.cipher_state)

        # Undo the three full rounds, last round key first.
        for i in range(3, 0, -1):
            self.__round_decrypt(self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)])

        self.__add_round_key(self.cipher_state, self.round_keys[:4])

        return matrix2text(self.cipher_state)

    def __add_round_key(self, s, k):
        """XOR the 4x4 key matrix ``k`` into the state ``s`` in place."""
        for i in range(4):
            for j in range(4):
                s[i][j] ^= k[i][j]


    def __round_encrypt(self, state_matrix, key_matrix):
        """Apply one full encryption round to ``state_matrix`` in place."""
        self.__sub_bytes(state_matrix)
        self.__shift_rows(state_matrix)
        self.__mix_columns(state_matrix)
        self.__add_round_key(state_matrix, key_matrix)


    def __round_decrypt(self, state_matrix, key_matrix):
        """Undo one full encryption round on ``state_matrix`` in place."""
        self.__add_round_key(state_matrix, key_matrix)
        self.__inv_mix_columns(state_matrix)
        self.__inv_shift_rows(state_matrix)
        self.__inv_sub_bytes(state_matrix)

    def __sub_bytes(self, s):
        """Replace every state byte by its Sbox entry, in place."""
        for i in range(4):
            for j in range(4):
                s[i][j] = self.Sbox[s[i][j]]


    def __inv_sub_bytes(self, s):
        """Replace every state byte by its InvSbox entry, in place."""
        for i in range(4):
            for j in range(4):
                s[i][j] = self.InvSbox[s[i][j]]


    def __shift_rows(self, s):
        """Rotate the entries at inner index 1, 2 and 3 by 1, 2 and 3 positions across the outer lists."""
        s[0][1], s[1][1], s[2][1], s[3][1] = s[1][1], s[2][1], s[3][1], s[0][1]
        s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2]
        s[0][3], s[1][3], s[2][3], s[3][3] = s[3][3], s[0][3], s[1][3], s[2][3]


    def __inv_shift_rows(self, s):
        """Apply the opposite rotations of __shift_rows, in place."""
        s[0][1], s[1][1], s[2][1], s[3][1] = s[3][1], s[0][1], s[1][1], s[2][1]
        s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2]
        s[0][3], s[1][3], s[2][3], s[3][3] = s[1][3], s[2][3], s[3][3], s[0][3]

    def __mix_single_column(self, a):
        """Mix the four bytes of the list ``a`` in place using XOR and xtime."""
        t = a[0] ^ a[1] ^ a[2] ^ a[3]
        # a[0] is overwritten below but still needed for the last entry.
        u = a[0]
        a[0] ^= t ^ xtime(a[0] ^ a[1])
        a[1] ^= t ^ xtime(a[1] ^ a[2])
        a[2] ^= t ^ xtime(a[2] ^ a[3])
        a[3] ^= t ^ xtime(a[3] ^ u)


    def __mix_columns(self, s):
        """Apply __mix_single_column to each of the four inner lists of ``s``."""
        for i in range(4):
            self.__mix_single_column(s[i])


    def __inv_mix_columns(self, s):
        """Invert __mix_columns in place.

        Each inner list is first pre-processed with doubled xtime values, then
        the forward mixing step is applied.
        """
        for i in range(4):
            u = xtime(xtime(s[i][0] ^ s[i][2]))
            v = xtime(xtime(s[i][1] ^ s[i][3]))
            s[i][0] ^= u
            s[i][1] ^= v
            s[i][2] ^= u
            s[i][3] ^= v

        self.__mix_columns(s)

class safeCipher:
    """Thin ECB-mode wrapper around a block-cipher algorithm class.

    ``alg`` is called with ``key`` to build the algorithm object, which is
    combined with ``modes.ECB()`` into ``self.cipher``.
    """

    def __init__(self, key, alg):
        """Build the underlying ECB cipher for ``alg(key)``."""
        self.cipher = Cipher(alg(key), modes.ECB())

    def encrypt(self, pt):
        """Encrypt ``pt`` and return the ciphertext bytes.

        Raises:
            AssertionError: if finalizing the encryptor yields extra bytes.
        """
        encryptor = self.cipher.encryptor()
        ct = encryptor.update(pt)
        # finalize() is called outside of an ``assert`` so that it still runs
        # (and still validates the input) under ``python -O``.
        tail = encryptor.finalize()
        if tail != b'':
            raise AssertionError("unexpected trailing bytes from encryptor.finalize()")
        return ct

    def decrypt(self, ct):
        """Decrypt ``ct`` and return the plaintext bytes.

        Raises:
            AssertionError: if finalizing the decryptor yields extra bytes.
        """
        decryptor = self.cipher.decryptor()
        pt = decryptor.update(ct)
        # Same reasoning as in encrypt(): never hide finalize() in an assert.
        tail = decryptor.finalize()
        if tail != b'':
            raise AssertionError("unexpected trailing bytes from decryptor.finalize()")
        return pt

class Camellia(safeCipher):
    """safeCipher specialised to ``algorithms.Camellia``."""

    def __init__(self, key):
        """Build the cipher for ``key``."""
        super().__init__(key, algorithms.Camellia)

class SEED(safeCipher):
    """safeCipher specialised to ``algorithms.SEED``."""

    def __init__(self, key):
        """Build the cipher for ``key``."""
        super().__init__(key, algorithms.SEED)

# Exploit driver: recover the cipher choice sequence and derive the flag.
HOST = "47.96.164.154"
PORT = 62344
conn = remote(HOST, PORT)
DEBUG = False
if DEBUG:
    context.log_level = "debug"


# The service announces bias as a hex string.
conn.recvuntil(b"Security bias: ")
bias = int(conn.recvline().strip().decode(), 16)
log.info(f"bias: {bias} {hex(bias)}")
# Estimated number of base-3 digits of bias; the attack needs it to be odd.
bits = math.ceil(math.log(bias, 3))
if bits & 1 == 0:
    # NOTE(review): the message is labelled "bias=" but prints `bits`.
    log.info(f"Try again: bias={bits}")
    exit(-1)

# Request encryption of 2*bits blocks (32 hex characters per block).
conn.sendlineafter(b"choice: ", b"0")
conn.sendlineafter(b"Encrypt: ", b"0"*32*bits*2)

ciphers = bytes.fromhex(conn.recvline().strip().decode())
# Keep only every other 16-byte block (step 32, width 16): entry i of
# cipher_list is block number 2*i of the response.
cipher_list = []
for i in range(0, len(ciphers), 32):
    cipher_list.append(ciphers[i:i+16])



# Guess for the high part of the seed: eight flag bytes plus the high part
# of bias (any carry from the lower part is ignored here).
seed = int.from_bytes(b"flag{E4s",'big') + (bias >> 128)
index = []
for i, ci in enumerate(cipher_list):
    # Key for block 2*i, truncated to 128 bits.
    key = (seed + 2*i) & ((1<<128)-1)
    key = key.to_bytes(16,"big")
    # Identify which algorithm produced this block by trial encryption of
    # an all-zero block; record it as the digit 0, 1 or 2.
    for alg in [backdoorAES, Camellia, SEED]:
        c = alg(key)
        if c.encrypt(b"\x00"*16) == ci:
            print(i, alg)
            if alg == backdoorAES:
                index.append(0)
            elif alg == Camellia:
                index.append(1)
            elif alg == SEED:
                index.append(2)

# Nothing matched: the key guess was wrong for this connection.
if len(index) == 0:
    log.info("Try again!")
    exit(-1)

print(len(index), index)
# Interleave the first half of `index` with the part after its midpoint,
# then append the midpoint entry last.
new_index = []
for i in range(len(index)//2):
    new_index.append(index[i])
    new_index.append(index[len(index)//2+1+i])

new_index.append(index[len(index)//2])
print(len(new_index), new_index)

# Read the digit sequence as a base-3 number and subtract bias; the
# difference is printed as 24 big-endian bytes.
s = int(''.join(str(n) for n in new_index), 3)
print(s, bias)
print((s-bias).to_bytes(24, 'big'))

等到bias的三进制位数(即len(cipher_list))为奇数时就可以了,但是还是有概率会失败,多跑几次就行啦