吾爱破解 - 52pojie.cn

 找回密码
 注册[Register]

QQ登录

只需一步,快速开始

查看: 4123|回复: 30
收起左侧

[Python 原创] 非线性组合序列的快速相关攻击算法A和算法B

  [复制链接]
samiya 发表于 2023-7-8 17:06
自己做密码算法,写的很垃圾,需要拿走
[Python] 纯文本查看 复制代码
import math
from scipy.special import comb
from lfsr import LFSR

#计算1的位数
def bit_count(val):
    cnt = 0
    while val:
        cnt += 1
        val = val & (val - 1)
    return cnt

#汉明距离变换
hamming = {15 : [[]] * 16, 16 : [[]] * 17}
for i in range(1, 1 << 16):
    cnt = bit_count(i)
    hamming[16][cnt].append(i)
    if i < (1 << 15):
        hamming[15][cnt].append(i)


class MS:
    def __init__(self, mask, n, z, p):
        self.mask = mask    #抽头掩码
        self.n = n          #级数
        self.z = list(z)    #密钥流序列
        self.p = p          #相关概率
        self.len_z = len(z) #密钥流长度
        self.t = 0          #抽头数
        for i in range(n):
            if (1 << i) & self.mask:
                self.t += 1
        self.m = round(self.M())        #平均每一位方程数
        self.s_init = self.S(self.t)    #每一位正确的后验概率s

    @staticmethod
    def bit_stream_to_int(a):
        return int(''.join(map(str, a)), 2)

    def M(self):
        return math.log2(self.len_z / (2 * self.n)) * (self.t + 1)

    def S(self, t):
        if t == 1:
            return self.p
        return self.p * self.S(t - 1) + (1 - self.p) * (1 - self.S(t - 1))
    
    #生成校验等式
    def get_eq(self):
        tap = [self.n]
        for i in range(self.n):
            if (self.mask >> i) & 1:
                tap.append(self.n - i - 1)
        tap.reverse()
        eqs = [tap]
        while True:
            if (tap[-1] << 1) >= self.len_z:
                break
            tmp = tap.copy()
            for i, val in enumerate(tmp):
                tmp[i] = val << 1
            eqs.append(tmp)
            tap = tmp
        return eqs

    #算法A得到loc位正确的概率
    def calc_eq(self, eqs, loc):
        shift_eqs = []
        for eq in eqs:
            for pos in eq:
                offset = loc - pos
                if eq[0] + offset < 0 or eq[-1] + offset >= self.len_z:
                    continue
                shift_eqs.append([i + offset for i in eq])
        m = len(shift_eqs)
        if m == 0:
            return 0, 0, 0
        h = 0
        for eq in shift_eqs:
            # print(eq)
            xor_sum = 0
            for i in eq:
                xor_sum ^= self.z[i]
            if xor_sum == 0:
                h += 1
        p1 = comb(m, h) * pow(self.s_init, h) * pow(1 - self.s_init, m - h)
        p0 = comb(m, h) * pow(self.s_init, m - h) * pow(1 - self.s_init, h)
        return m, h, p1 / (p1 + p0)

    #生成矩阵
    def gen_linear_eq(self):
        length = max(self.len_z, self.n)
        tap = []
        for i in range(self.n):
            if (self.mask >> i) & 1:
                tap.append(i + 1)
        eqs = []
        for i in range(self.n):
            eqs.append(1 << i)
        for i in range(self.n, length):
            res = 0
            for j in tap:
                res ^= eqs[i - j]
            eqs.append(res)
        return eqs

    #根据生成矩阵解方程
    @staticmethod
    def solve(assume, n):
        eq_len = len(assume)
        mat = []
        for i in range(eq_len):
            mat.append([0] * n)
        b = [0] * eq_len
        for i in range(eq_len):
            b[i] = assume[i][1]
            for j in range(n):
                mat[i][j] = (assume[i][0] >> j) & 1
        for i in range(n):
            tmp = -1
            for j in range(i, eq_len):
                if mat[j][i]:
                    tmp = j
                    break
            if tmp == -1:
                return []
            mat[tmp], mat[i] = mat[i], mat[tmp]
            b[tmp], b[i] = b[i], b[tmp]
            for j in range(eq_len):
                if not mat[j][i] or i == j:
                    continue
                b[j] ^= b[i]
                for k in range(i, n):
                    mat[j][k] ^= mat[i][k]
        if not any(mat[n - 1]):
            return []
        # print(b[:n])
        return b[:n]

    #计算初始状态并验证
    def get_init_stat(self, locs, linear_eq):
        assume = [(linear_eq[x[0]], x[1]) for x in locs]
        b = []
        idx = self.n
        # print("----- try solve equations -----")
        while not b:
            b = MS.solve(assume[:idx], self.n)
            idx += 1
        # print("----- solve success -----")
        stat = MS.bit_stream_to_int(b)

        # print("----- genrate original LFSR -----")
        l = LFSR(stat, self.mask, self.n)
        for i in range(self.n):
            l.step_back()
        init_stat = l.init
        # print("init:", init_stat)
        # print("----- genrate original LFSR finished -----")

        same_cnt = 0
        for i in range(self.len_z):
            same_cnt += int(self.z[i] == l.next())
        rate = same_cnt / self.len_z
        if abs(rate - self.p) < 0.05:
            return init_stat
        else:
            #hamming变换
            for i in range(self.n + 1):
                for filp in hamming[self.n][i]:
                    change = init_stat ^ filp
                    l.init = change
                    cnt = 0
                    for i in range(self.len_z):
                        cnt += int(self.z[i] == l.next())
                    rate = cnt / self.len_z
                    if abs(rate - self.p) < 0.05:
                        return change
        return init_stat

    #算法A攻击函数
    def crackA(self):
        eqs = self.get_eq()
        # print("----- select candidates -----")
        candidates = []
        for i in range(self.len_z):
            m, h, p_star = self.calc_eq(eqs, i)
            if p_star > 0.5:
                candidates.append((p_star, i, m, h))
        candidates.sort(reverse=True)
        # candidates = candidates[:2*self.n]
        # print(candidates[:5])
        # print("----- select candidates finished -----")
        linear_eq = self.gen_linear_eq()
        locs = [(cand[1], self.z[cand[1]]) for cand in candidates]
        return self.get_init_stat(locs, linear_eq)

    #计算s_init
    @staticmethod
    def var_S(var_p, t):
        assert(t == len(var_p))
        if t == 1:
            return var_p[0]
        s = MS.var_S(var_p[:-1], t - 1)
        return var_p[-1] * s + (1 - var_p[-1]) * (1 - s)

    def Q(self, h):
        res = 0
        for i in range(h + 1):
            res += comb(self.m, i) * (self.p_false(i) + self.p_true(i))
        return res

    def I(self, h):
        res = 0
        for i in range(h + 1):
            res += comb(self.m, i) * (self.p_false(i) - self.p_true(i))
        return res

    def p_true(self, h):
        return self.p * pow(self.s_init, h) * pow(1 - self.s_init, self.m - h)

    def p_false(self, h):
        return (1 - self.p) * pow(1 - self.s_init, h) * pow(self.s_init, self.m - h)

    def p_update(self, h):
        t = self.p_true(h)
        f = self.p_false(h)
        return t / (t + f)

    def calc_h_max(self):
        h_max = 0
        I_max = 0
        for i in range(self.m + 1):
            newI = self.I(i)
            if newI > I_max:
                I_max = newI
                h_max = i
        return h_max

    #计算迭代后验概率的4个部分
    def parcheck(self, eqs, arr_p):
        poly = self.t + 1
        s = [[1, 1, 1, 1] for i in range(self.len_z)]
        for eq in eqs:
            curtaps = eq.copy()
            offset = self.len_z - eq[-1]
            # print('\n' , offset , '\n')
            for i in range(offset):
                # print(i , ' ')
                xor_sum = 0
                for tap in curtaps:
                    xor_sum ^= self.z[tap]
                for j in range(poly):
                    var_p = [arr_p[curtaps[k]] for k in range(poly) if k != j]
                    cur_s = MS.var_S(var_p, self.t)
                    cur_bit = curtaps[j]
                    if xor_sum == 0:
                        s[cur_bit][0] *= cur_s
                        s[cur_bit][1] *= 1 - cur_s
                    else:
                        s[cur_bit][2] *= cur_s
                        s[cur_bit][3] *= 1 - cur_s
                for j in range(poly):
                    curtaps[j] += 1
        return s

    #计算迭代后验概率
    @staticmethod
    def var_p_update(p, arr_s):
        t = p * arr_s[0] * arr_s[3]
        div = t + (1 - p) * arr_s[1] * arr_s[2]
        #问题出在这里
        if div == 0:
            return 1
        return t / (t + (1 - p) * arr_s[1] * arr_s[2])

    #检验是否满足校验方程
    @staticmethod
    def check(eq, z):
        offset = len(z) - eq[-1]
        for i in range((offset)):
            xor_sum = 0
            for tap in eq:
                xor_sum ^= z[tap]
            if xor_sum:
                return False
            for j, val in enumerate(eq):
                eq[j] = val + 1
        return True

    #算法B攻击函数
    def crackB(self):
        prim_z = self.z.copy()
        eqs = self.get_eq()
        h_max = self.calc_h_max()
        p_thr = (self.p_update(h_max) + self.p_update(h_max + 1)) / 2
        N_thr = self.Q(h_max) * self.len_z
        # print("------------------------------------------")
        # print("Round\tIteration\t  N_w")
        # print("------------------------------------------")
        for r in range(1, 1000):
            arr_p = [self.p for i in range(self.len_z)]
            for iter in range(1, 6):
                arr_ss = self.parcheck(eqs, arr_p)
                for i in range(self.len_z):
                    arr_p[i] = MS.var_p_update(arr_p[i], arr_ss[i])
                N_w = 0
                for i in range(self.len_z):
                    if arr_p[i] < p_thr:
                        N_w += 1
                # print(r, '\t', iter, '\t\t', N_w)
                if N_w >= N_thr:
                    break
            cnt_filp_bit = 0
            for pos in range(self.len_z):
                if arr_p[pos] < p_thr:
                    self.z[pos] ^= 1
                    cnt_filp_bit += 1
            # print("------------------------------------------")
            if MS.check(eqs[0][:], self.z) or cnt_filp_bit == 0:
                stat = MS.bit_stream_to_int(self.z[:self.n])
                l = LFSR(stat, self.mask, self.n)
                for i in range(self.n):
                    l.step_back()
                self.z = prim_z
                return l.init
        return list() 

免费评分

参与人数 11威望 +2 吾爱币 +114 热心值 +10 收起 理由
SINCERLY + 1 + 1 欢迎分析讨论交流,吾爱破解论坛有你更精彩!
theStyx + 2 + 1 用心讨论,共获提升!
Cofei430 + 1 + 1 谢谢@Thanks!
allspark + 1 + 1 用心讨论,共获提升!
初七的果子狸 + 2 + 1 谢谢@Thanks!
timeslover + 3 谢谢@Thanks!
sanyv + 1 我很赞同!
fengbolee + 2 + 1 欢迎分析讨论交流,吾爱破解论坛有你更精彩!
苏紫方璇 + 2 + 100 + 1 首贴鼓励
motto + 1 + 1 用心讨论,共获提升!
p9rsu9 + 1 + 1 感谢发布原创作品,吾爱破解论坛因你更精彩!

查看全部评分

发帖前要善用论坛搜索功能,那里可能会有你要找的答案或者已经有人发布过相同内容了,请勿重复发帖。

motto 发表于 2023-7-10 00:12
容我先理解一下
tzblue 发表于 2023-7-10 06:08
感谢分享!
只会C、C++,还好里面的语法规则看得懂。
头像被屏蔽
mmattic 发表于 2023-7-10 11:03
zhengsg5 发表于 2023-7-10 12:58
谢谢分享~~~
Marco. 发表于 2023-7-10 13:30
学习学习
DQQQQQ 发表于 2023-7-10 15:14
这个世界最难得不是看不懂,而是都是专业术语。。。确没有解释
言念君子 发表于 2023-7-10 23:01
时间复杂度有点高了
ruixingzhe 发表于 2023-7-11 13:19
先容我学学密码学…感谢大佬分享
Mint111 发表于 2023-7-11 17:22
让我先学学密码学这里太难懂了
您需要登录后才可以回帖 登录 | 注册[Register]

本版积分规则

返回列表

RSS订阅|小黑屋|处罚记录|联系我们|吾爱破解 - LCG - LSG ( 京ICP备16042023号 | 京公网安备 11010502030087号 )

GMT+8, 2024-12-4 01:34

Powered by Discuz!

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表