0

我想让这个 Shamir 秘密共享算法的实现除了数字之外也能接受字符串。用 ShamirEncoder.int_from_string() 把字符串转换成数字很容易,但我不知道如何把重建出来的数字(也就是它的字节)再转换回字符串,因为重建得到的输出就是 int_from_string() 方法生成的那个数字。

ShamirEncoder.string_from_int() 是我从另一个实现中找到的,但它在这个实现中无法正常工作。

import random
from decimal import Decimal
from fractions import Fraction

from mod import Mod

# Field sizes offered by ShamirEncoder.find_best_prime_for_secret, listed in
# ascending order.  Each must be prime: modular Lagrange interpolation
# (ShamirEncoder.string_from_int) divides by differences of x-coordinates,
# and every non-zero difference is invertible only modulo a prime.
# (2**10, 2**30 and 2**511 - 1 are all composite -- 511 = 7 * 73, so
# 2**511 - 1 is divisible by 2**7 - 1.)
FIELD_SIZE_SMALL: int = 2 ** 10 - 3        # 1021, largest prime below 2**10
FIELD_SIZE_MEDIUM: int = 2 ** 30 - 35      # prime just below 2**30
FIELD_SIZE_128BITS: int = 2 ** 127 - 1     # Mersenne prime M127
FIELD_SIZE_512BITS: int = 2 ** 511 - 187   # prime (modulus of curve M-511)
FIELD_SIZE_HUGE: int = 2 ** 521 - 1        # Mersenne prime M521


class ShamirEncoder:
    """Split an integer (or UTF-8 string) secret into Shamir-style shares.

    The secret is the constant term of a random polynomial of degree
    ``minimum - 1``; each share is a point ``[x, p(x)]`` on it.  To share
    text, pass a ``str`` as ``secrets``; after reconstruction,
    :meth:`string_from_secret` turns the recovered integer back into text.

    NOTE(review): ``polynom`` evaluates the polynomial over the integers and
    ``generate_shares`` does not reduce ``p(x)`` modulo ``field_size``, so the
    shares are plain (large) integers rather than field elements.  Textbook
    Shamir works modulo a prime, which is what gives information-theoretic
    secrecy.  Switching would also require the decoder to be told the prime.
    """

    shares: list  # set by generate_shares(); absent until it has been called

    # Shares protect secrets, so draw randomness from the OS CSPRNG; the
    # random-module docs say its default generator must not be used for
    # security purposes.
    _rng = random.SystemRandom()

    def __init__(self, nshares: int, minimum: int, secrets: "int | str"):
        """Prepare to split *secrets* into *nshares* shares with threshold *minimum*.

        Args:
            nshares: number of shares to generate.
            minimum: threshold; the polynomial has ``minimum - 1`` random
                coefficients above the constant term.
            secrets: the secret as a non-negative int, or a str, which is
                converted with :meth:`int_from_string`.

        Raises:
            ValueError: if the secret is too large for every field size.
        """
        if isinstance(secrets, str):
            secrets = self.int_from_string(secrets)
        self.n = nshares
        self.m = minimum
        self.secret = secrets
        self.field_size = self.find_best_prime_for_secret(self.secret)

    @staticmethod
    def polynom(x: int, coeffs: list):
        """Evaluate the polynomial *coeffs* (highest degree first) at *x*.

        Uses Horner's rule; an empty coefficient list evaluates to 0.
        """
        acc = 0
        for c in coeffs:
            acc = acc * x + c
        return acc

    def coeff(self):
        """Return ``minimum - 1`` random coefficients followed by the secret.

        The list is ordered highest degree first, as :meth:`polynom` expects,
        so the secret ends up as the constant term.
        """
        cfs = [self._rng.randrange(0, self.field_size) for _ in range(self.m - 1)]
        cfs.append(self.secret)
        return cfs

    def generate_shares(self):
        """Create ``nshares`` shares and store them in ``self.shares``.

        Each share is ``[x, p(x)]`` with a distinct random ``x`` drawn from
        ``[1, field_size)``.  Distinctness matters: two shares with the same
        ``x`` make Lagrange interpolation divide by zero.

        Raises:
            ValueError: if ``nshares`` exceeds the ``field_size - 1``
                available x-coordinates.
        """
        if self.n > self.field_size - 1:
            raise ValueError(
                f'cannot draw {self.n} distinct x-coordinates from [1, {self.field_size})'
            )
        cfs = self.coeff()
        xs = []
        seen = set()
        # Rejection sampling instead of sample(range(...)): len() of a range
        # wider than sys.maxsize (e.g. the 521-bit field) raises OverflowError.
        while len(xs) < self.n:
            x = self._rng.randrange(1, self.field_size)
            if x not in seen:
                seen.add(x)
                xs.append(x)
        self.shares = [[x, self.polynom(x, cfs)] for x in xs]

    def get_shares(self):
        """Return the shares created by the last :meth:`generate_shares` call."""
        return self.shares

    @staticmethod
    def int_from_string(s: str):
        """Encode *s* as UTF-8 and read the bytes as a big-endian unsigned int."""
        return int.from_bytes(s.encode('utf-8'), 'big')

    @staticmethod
    def string_from_secret(secret: int) -> str:
        """Invert :meth:`int_from_string`: turn a recovered secret back into text.

        Leading NUL characters of the original string cannot be recovered,
        because they encode to leading zero bytes that vanish in the integer.

        Raises:
            ValueError: if *secret* is negative.
            UnicodeDecodeError: if the bytes are not valid UTF-8 (e.g. the
                secret was reconstructed from the wrong shares).
        """
        if secret < 0:
            raise ValueError('secret must be non-negative')
        return secret.to_bytes((secret.bit_length() + 7) // 8, 'big').decode('utf-8')

    @staticmethod
    def string_from_int(shares: list, field_size: int):
        """Recover the secret from *shares* by Lagrange interpolation at x = 0 mod *field_size*.

        Despite its name this returns a ``Mod`` (the secret reduced modulo
        *field_size*), not a string.  To turn a recovered integer back into
        text, use :meth:`string_from_secret`.  *field_size* should be prime,
        so that every difference of distinct x-coordinates is invertible.

        Raises:
            ValueError: if an x-coordinate difference has no inverse modulo
                *field_size* (duplicate x-coordinates or a composite modulus).
        """
        xs = [share[0] for share in shares]
        acc = 0
        for i, cur in enumerate(xs):
            # Lagrange basis L_i(0) = prod over j != i of x_j / (x_j - x_i).
            factor = 1
            for j, el in enumerate(xs):
                if j != i:
                    inv = ShamirEncoder._mod_inverse(el - cur, field_size)
                    factor = factor * el * inv % field_size
            acc = (acc + factor * shares[i][1]) % field_size
        return Mod(acc, field_size)

    @staticmethod
    def _mod_inverse(a: int, m: int) -> int:
        """Return the inverse of *a* modulo *m* via the extended Euclidean algorithm.

        Raises:
            ValueError: if gcd(a, m) != 1, i.e. no inverse exists.
        """
        old_r, r = a % m, m
        old_s, s = 1, 0
        while r:
            q = old_r // r
            old_r, r = r, old_r - q * r
            old_s, s = s, old_s - q * s
        if old_r != 1:
            raise ValueError(f'{a} has no inverse modulo {m}')
        return old_s % m

    @staticmethod
    def find_best_prime_for_secret(secret: int):
        """Return the first field size, checked in ascending order, that exceeds *secret*.

        Raises:
            ValueError: if *secret* is not below ``FIELD_SIZE_HUGE``.
        """
        for size in (FIELD_SIZE_SMALL, FIELD_SIZE_MEDIUM, FIELD_SIZE_128BITS,
                     FIELD_SIZE_512BITS, FIELD_SIZE_HUGE):
            if secret < size:
                return size
        raise ValueError('no prime that works')


class ShamirDecoder:
    """Reconstruct a secret from shares produced by ShamirEncoder."""

    @staticmethod
    def reconstruct_secret(shares: list):
        """Recover the secret by Lagrange interpolation of *shares* at x = 0.

        Each share is indexable as ``(x, y)``.  The arithmetic is exact
        (``fractions.Fraction``).  ``Decimal`` keeps only 28 significant
        digits by default, which cannot hold the share values produced for
        large secrets (e.g. strings encoded with
        ``ShamirEncoder.int_from_string``), so the result would be silently
        wrong.

        Returns:
            The secret as an int: the interpolated value rounded to the
            nearest integer (ties to even).  The value is exact when the
            shares are consistent.  An empty list yields 0.

        Raises:
            ZeroDivisionError: if two shares have the same x-coordinate.
        """
        total = Fraction(0)
        for j, share_j in enumerate(shares):
            xj, yj = Fraction(share_j[0]), Fraction(share_j[1])
            # Lagrange basis L_j(0) = prod over i != j of x_i / (x_i - x_j).
            basis = Fraction(1)
            for i, share_i in enumerate(shares):
                if i != j:
                    xi = Fraction(share_i[0])
                    basis *= xi / (xi - xj)
            total += basis * yj
        # round() on a Fraction returns an int; it only matters for
        # inconsistent shares, since valid ones interpolate to an integer.
        return int(round(total))


if __name__ == '__main__':
    # Demo: split an integer secret into `total` shares, then rebuild it
    # from a random subset of `minimum_shares` of them.
    total, minimum_shares = 5, 3
    password = 1234

    def _print_section(title, rows):
        """Print a heading followed by one row per line."""
        print(title)
        for row in rows:
            print(row)

    encoder = ShamirEncoder(nshares=total, minimum=minimum_shares, secrets=password)
    encoder.generate_shares()
    shares = encoder.get_shares()
    _print_section('Shares:', shares)

    pool = random.sample(shares, minimum_shares)
    _print_section('\nUsing shares:', pool)

    reconstructed_secret = ShamirDecoder.reconstruct_secret(pool)
    print(f'\nReconstructed secret: {reconstructed_secret}')

4

0 回答 0