1

我正在研究/评估解决二次丢番图方程组的技术方法。我的具体问题可以归结为以下两个步骤:

  1. 加载包含元组行的文本文件,[sqrt(s), sqrt(t), sqrt(u), s, t, u, t+u, t+u-s, t-s]其中每个元素都是整数。该文件的截图如下。
  2. 对于此文件中的每一行:搜索[w,x,y,z]求解以下方程组的整数四元组: [x^2-w^2=s][y^2-w^2=t][z^2-y^2=u][z^2-w^2=t+u]和。[z^2-x^2=t+u-s][y^2-x^2=t-s]

这是文本文件的截图:

520, 533, 756, 270400, 284089, 571536, 855625, 585225, 13689
672, 680, 153, 451584, 462400, 23409, 485809, 34225, 10816
756, 765, 520, 571536, 585225, 270400, 855625, 284089, 13689
612, 740, 2688, 374544, 547600, 7225344, 7772944, 7398400, 173056
644, 725, 2040, 414736, 525625, 4161600, 4687225, 4272489, 110889

到目前为止,我尝试的是使用z3求解器,它可以编译并运行,但不幸的是速度很慢:

import pandas as pd
import sys
from z3 import Ints, solve

def main() -> int:
    df = pd.read_csv('tuples.txt', header=None)
    
    tuples = df.to_numpy()

    x, y, z, w = Ints('x y z w')
    for row in tuples:
        s=int(row[3])
        t=int(row[4])
        u=int(row[5])
        solve(x*x-w*w==s, y*y-w*w==t, z*z-y*y==u, w!=0)

    return 0

if __name__ == '__main__':
    sys.exit(main())

如果有任何替代方法(最佳实践)在 Python 中解决这种丢番图系统,我将非常感激。

4

1 回答 1

3

在 Python 中为您创建了非常庞大但非常快速的解决方案。z3它应该比任何做类似事情的 Mathematica 代码或 -solver 代码解决得更快。当然,在预计算阶段之后,它只完成一次,然后可以在多次运行中重复使用(它们将所有计算数据保存到缓存文件中)。

以下解决方案进行了两次预计算。第一个需要几分钟,它预先计算 2.7 GB 的文件,这是一个大的正方形过滤器。这个尺寸是可调整的,可以更小。该文件仅计算一次(除非您更改设置)并在每次运行时重复使用。这种预计算是单核的(但经过一些努力,我可以使它成为多核)。

第二个预计算需要更多时间,这个是多核的,它使用所有 CPU 内核。此预计算生成的文件非常小,即使对于较大的参数值也不到 1 GB。该预计算表存储所有可能的具有整数边的毕达哥拉斯直角三角形。

对所有小于limit尺寸的导管进行预计算。将当前设置更改limit = 100_000为更大的值,在您的情况下可能为 1M。如果这个表太小,那么它将无法为大型导管找到一些解决方案。预先计算的表也存储在磁盘上,并在每次运行时重复使用(不再计算)。

第二个预计算计算以下直角三角形表。它遍历所有可能的第一个整数cathetus A(达到限制)并找到所有可能的第二个整数cathetus B(达到限制)使得A^2 + B^2 = C^2,其中C 也是整数。然后对于每个 A,它存储一组满足该方程的 B。C 没有被存储,因为它可以很容易地从 A 和 B 中计算出来。

为了快速搜索 B,我构建了两个过滤器。例如,我们有任何整数 K0 和 K1。我们可以很容易地看出,如果 X 是一个正方形,那么 X % K0 是一个正方形,X % K1 也是一个正方形。因此,我们可以构建一个大小为 K0 的表,如果 table[X % K0] 是平方模 K0 则为 1,否则为 0。这为我们提供了一个快速过滤器,用于删除所有此类绝对非正方形的 X(即 table[X % K0] 为 0)。第二个 K1 滤波器用于第二阶段的额外滤波。

以下 Python 代码可以立即运行,无需依赖,它会自动从您的 GitHub 获取 STU 文件并将其缓存在磁盘上。

完成上述两次预计算后,所有数千个 s/t/u 解决方案都在 1-2 秒内计算完毕。最后,所有解决方案都以 JSON 格式存储到文件stu_solutions.100000中。

找到的几乎解决方案(具有非整数 Z)可以通过命令转储:

cat stu_solutions.100000 | grep false

找到的精确解(整数 Z)可以通过命令转储:

cat stu_solutions.100000 | grep true

该文件的其余行包含有错误的解决方案(如果表对他们来说太小),或者在找不到 w、x、y 时包含零解决方案。如果出现错误,您必须通过设置更大的limit = ....

有必要设置至少与 一样大的限制Max(Sqrt(s), Sqrt(t))。但最好将其设置为大几倍。表越大,找到所有可能的解决方案的机会就越高。限制最多需要尽可能大w

要运行以下 Python 代码,您必须安装一次性 PIP 包python -m pip install numba numpy requests

在线尝试!

numba = None
import numba

import json, multiprocessing, time, timeit, os, math, numpy as np

if numba is None:
    class NumbaInt:
        def __getitem__(self, key):
            return None
    class numba:
        uint8, uint16, uint32, int64, uint64 = [NumbaInt() for i in range(5)]
        def njit(*pargs, **nargs):
            return lambda f: f
        def prange(*pargs):
            return range(*pargs)
        class types:
            class Tuple:
                def __init__(self, *nargs, **pargs):
                    pass
                def __call__(self, *nargs, **pargs):
                    pass
        class objmode:
            def __init__(self, *pargs, **nargs):
                pass
            def __enter__(self):
                return self
            def __exit__(self, ext, exv, tb):
                pass

@numba.njit(cache = True, parallel = True)
def create_filters():
    Ks = [np.uint32(e) for e in [2 * 2 * 3 * 5 * 7 * 11 * 13,    17 * 19 * 23 * 29 * 31 * 37]]
    filts = []
    for i in range(len(Ks)):
        K = Ks[i]
        filt = np.zeros((K,), dtype = np.uint8)
        block = 1 << 25
        nblocks = (K + block - 1) // block
        for j0 in numba.prange(nblocks):
            j = j0 * block
            a = np.arange(j, min(j + block, K)).astype(np.uint64)
            a *= a; a %= K
            filt[a] = 1
        idxs = np.flatnonzero(filt).astype(np.uint32)
        filts.append((K, filt, idxs))
        print(f'filter {i} ratio', round(len(filts[-1][2]) / K, 4))
    return filts

@numba.njit('u2[:, :, :](u4, u4[:])', cache = True, parallel = True, locals = dict(
    t = numba.uint32, s = numba.uint32, i = numba.uint32, j = numba.uint32))
def filter_chain(K, ix):
    assert len(ix) < (1 << 16)
    ix_rev = np.full((K,), len(ix), dtype = np.uint16)
    for i, e in enumerate(ix):
        ix_rev[e] = i
    r = np.zeros((len(ix), K, 2), dtype = np.uint16)
    
    print('filter chain pre-computing...')
    
    for i in numba.prange(K):
        if i % 5000 == 0 or i + 1 >= K:
            with numba.objmode():
                print(f'{i}/{K}, ', end = '', flush = True)
        for j, x in enumerate(ix):
            t, s = i, x
            while True:
                s += 2 * t + 1; s %= K
                t += 1
                if ix_rev[s] < len(ix):
                    assert t - i < (1 << 16)
                    assert t - i < K
                    r[j, i, 0] = ix_rev[s]
                    r[j, i, 1] = np.uint16(t - i)
                    break
    
    print()
    
    return r

def filter_chain_create_load(K, ix):
    fname = f'filter_chain.{K}'
    if not os.path.exists(fname):
        r = filter_chain(K, ix)
        with open(fname, 'wb') as f:
            f.write(r.tobytes())
    with open(fname, 'rb') as f:
        return np.copy(np.frombuffer(f.read(), dtype = np.uint16).reshape(len(ix), K, 2))

@numba.njit(
    #'void(i8, i8, u4, u1[:], u4[:], u2[:, :, :], u4, u1[:])',
    numba.types.Tuple([numba.uint64[:], numba.uint32[:]])(
        numba.int64, numba.int64, numba.uint32, numba.uint8[:],
        numba.uint32[:], numba.uint16[:, :, :], numba.uint32, numba.uint8[:]),
    cache = True, parallel = True,
    locals = dict(x = numba.uint64, Atpos = numba.uint64, Btpos = numba.uint64, bpos = numba.uint64))
def create_table(limit, cpu_count, k0, f0, fi0, fc0, k1, f1):
    print('Computing tables...')
    
    def gen_squares_candidates_A(cnt, lim, off, t, K, f, fi, fc):
        mark = np.zeros((np.int64(K),), dtype = np.uint8)
        while True:
            start_s = np.int64((np.int64(off) + np.int64(t) ** 2) % K)
            tK = np.uint32(np.int64(t) % np.int64(K))
            if mark[tK]:
                return np.zeros((0,), dtype = np.uint32)
            mark[tK] = 1
            if f[start_s]:
                break
            t += 1
        j = np.int64(np.searchsorted(fi, start_s))
        assert fi[j] == start_s
        r = np.zeros((np.int64(cnt),), dtype = np.uint32)
        r[0] = t
        rpos = np.int64(1)
        tK = np.uint32(np.int64(t) % np.int64(K))
        while True:
            j, dt = fc[j, tK]
            t += dt
            tK += dt
            if tK >= np.uint32(K):
                tK -= np.uint32(K)
            if t >= lim:
                return r[:rpos]
            if np.int64(rpos) >= np.int64(r.shape[0]):
                r = np.concatenate((r, np.zeros_like(r)), axis = 0)
            assert rpos < len(r)
            r[rpos] = t
            rpos += 1
    
    def gen_squares(cnt, lim, off, t, K, f, fi, fc, k1, f1):
        def is_square(x):
            assert x >= 0
            if not f1[np.int64(x) % np.uint32(k1)]:
                return False
            root = np.uint64(math.sqrt(np.float64(x)) + 0.5)
            return root * root == x
        rA = gen_squares_candidates_A(cnt, lim, off, t, K, f, fi, fc)
        r = np.zeros_like(rA)
        rpos = np.int64(0)
        for t in rA:
            if not is_square(np.int64(off) + np.int64(t) ** 2):
                continue
            assert np.int64(rpos) < np.int64(r.shape[0])
            r[rpos] = t
            rpos += 1
        return r[:rpos]
    
    with numba.objmode(gtb = 'f8'):
        gtb = time.time()
    
    search_start = 2
    cnt_limit = max(1 << 4, round(pow(limit, 0.66)))
    
    nblocks2 = cpu_count * 8
    nblocks = nblocks2 * 64
    block = (limit + nblocks - 1) // nblocks
    
    At = np.zeros((limit + 1,), dtype = np.uint64)
    Bt = np.zeros((0,), dtype = np.uint32)
    Atpos, Btpos = search_start + 1, 0
    
    with numba.objmode(tb = 'f8'):
        tb = time.time()
    for iMblock in range(0, nblocks, nblocks2):
        cur_blocks = min(nblocks, iMblock + nblocks2) - iMblock
        As = np.zeros((cur_blocks, block), dtype = np.uint64)
        As_size = np.zeros((cur_blocks,), dtype = np.uint64)
        Bs = np.zeros((cur_blocks, 1 << 16,), dtype = np.uint32)
        Bs_size = np.zeros((cur_blocks,), dtype = np.uint64)
        for iblock in numba.prange(cur_blocks):
            iblock0 = iMblock + iblock
            begin, end = max(search_start, iblock0 * block), min(limit, (iblock0 + 1) * block)
            begin = min(begin, end)
            #a = np.zeros((block,), dtype = np.uint64)
            #b = np.zeros((1 << 10,), dtype = np.uint32)
            bpos = 0
            for ix, x in enumerate(range(begin, end)):
                s = gen_squares(cnt_limit, limit, x ** 2, search_start, k0, f0, fi0, fc0, k1, f1)
                assert not (np.int64(bpos) + np.int64(s.shape[0]) > np.int64(Bs[iblock].shape[0]))
                #while np.int64(bpos) + np.int64(s.shape[0]) > np.int64(b.shape[0]):
                #    b = np.concatenate((b, np.zeros_like(b)), axis = 0)
                bpos_end = bpos + s.shape[0]
                Bs[iblock, bpos : bpos_end] = s
                As[iblock, ix] = bpos_end
                bpos = bpos_end
            As_size[iblock] = end - begin
            Bs_size[iblock] = bpos
        for iblock, (cA, cB) in enumerate(zip(As, Bs)):
            cA = cA[:As_size[iblock]]
            cB = cB[:Bs_size[iblock]]
            assert Atpos + cA.shape[0] <= At.shape[0]
            prevA = At[Atpos - 1]
            for e in cA:
                At[Atpos] = prevA + e
                Atpos += 1
            #while np.int64(Btpos) + np.int64(cB.shape[0]) > np.int64(Bt.shape[0]):
                #Bt = np.concatenate((Bt, np.zeros_like(Bt)), axis = 0)
                #Bt = np.concatenate((Bt, np.zeros(Bt.shape, dtype = np.uint32)), axis = 0)
            #assert np.int64(Btpos) + np.int64(cB.shape[0]) <= np.int64(Bt.shape[0])
            #assert cB.shape[0] > 0
            #Bt[Btpos : Btpos + cB.shape[0]] = cB
            Bt = np.concatenate((Bt, cB))
            #Btpos += cB.shape[0]
            #assert At[Atpos - 1] == Btpos
            assert At[Atpos - 1] == Bt.shape[0]
        with numba.objmode(tim = 'f8'):
            tim = max(0.001, round(time.time() - tb, 3))
        print(f'{str(min(limit, (iMblock + cur_blocks) * block) >> 10).rjust(len(str(limit >> 10)))}/{limit >> 10} K, ELA',
            round(tim / 60.0, 1), 'min, ETA', round((nblocks - (iMblock + cur_blocks)) * (tim / (iMblock + cur_blocks)) / 60.0, 1), 'min')
    
    assert Atpos == At.shape[0]
    
    with numba.objmode(gtb = 'f8'):
        gtb = time.time() - gtb
    
    print(f'Tables sizes: A {Atpos}, B {Bt.shape[0]}')
    print('Time elapsed computing tables:', round(gtb / 60.0, 1), 'min')
    
    return At, Bt
    
def table_create_load(limit, *pargs):
    fnameA = f'right_triangles_table.A.{limit}'
    fnameB = f'right_triangles_table.B.{limit}'
    if not os.path.exists(fnameA) or not os.path.exists(fnameB):
        A, B = create_table(limit, *pargs)
        with open(fnameA, 'wb') as f:
            f.write(A.tobytes())
        with open(fnameB, 'wb') as f:
            f.write(B.tobytes())
        del A, B
    with open(fnameA, 'rb') as f:
        A = np.copy(np.frombuffer(f.read(), dtype = np.uint64))
        assert A.shape[0] == limit + 1, (fnameA, A.shape[0], limit + 1)
    with open(fnameB, 'rb') as f:
        B = np.copy(np.frombuffer(f.read(), dtype = np.uint32))
        assert A[-1] == B.shape[0], (fnameB, A[-1], B.shape[0])
    print(f'Table A size {A.shape[0]}, B size {B.shape[0]}')
    return A, B

def find_solutions(tA, tB, stu):
    def is_square(x):
        root = np.uint64(math.sqrt(np.float64(x)) + 0.5)
        return bool(root * root == x), int(root)
    
    assert tA[-1] == tB.shape[0]
    
    fname = f'stu_solutions.{tA.shape[0] - 1}'
    with open(fname, 'w', encoding = 'utf-8') as fout:
        for s, t, u in stu:
            s, t, u = map(int, (s, t, u))
            r = {'stu': [s, t, u]}
            if s + 1 >= tA.shape[0]:
                r['err'] = f's {s} exceeds table A len {tA.shape[0]}'
            elif t + 1 >= tA.shape[0]:
                r['err'] = f't {t} exceeds table A len {tA.shape[0]}'
            else:
                r['res'] = []
                bs = tB[tA[s] : tA[s + 1]]
                ts = tB[tA[t] : tA[t + 1]]
                for w in np.intersect1d(bs, ts):
                    w = int(w)
                    x2 = s ** 2 + w ** 2
                    y2 = t ** 2 + w ** 2
                    x_isq, x = is_square(x2)
                    assert x_isq, (s, t, u, w, x2)
                    y_isq, y = is_square(y2)
                    assert y_isq, (s, t, u, w, x2, y2)
                    z2 = u ** 2 + y2
                    z_isq, z = is_square(z2)
                    r['res'].append({
                        'w': w,
                        'x': x,
                        'y': y,
                        'z2': z2,
                        'z2_is_square': z_isq,
                        'z': z if z_isq else math.sqrt(z2),
                    })
            fout.write(json.dumps(r, ensure_ascii = False) + '\n')
    
    print(f'STU solutions written to {fname}')

def solve(limit):
    import requests
    
    filts = create_filters()
    fc0 = filter_chain_create_load(filts[0][0], filts[0][2])
    
    tA, tB = table_create_load(limit, multiprocessing.cpu_count(),
        filts[0][0], filts[0][1], filts[0][2], fc0, filts[1][0], filts[1][1])
    
    # https://github.com/Sultanow/pythagorean/blob/main/data/pythagorean_stu_Arty_.txt?raw=true
    ifname = 'pythagorean_stu_Arty_.txt'
    iurl = f'https://github.com/Sultanow/pythagorean/blob/main/data/{ifname}?raw=true'
    if not os.path.exists(ifname):
        print(f'Downloading: {iurl}')
        data = requests.get(iurl).content
        with open(ifname, 'wb') as f:
            f.write(data)
    stu = []
    with open(ifname, 'r', encoding = 'utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            if 'elapsed' in line:
                continue
            s, t, u, *_ = eval(f'[{line}]')
            stu.append([s, t, u])
    print(f'Read {len(stu)} s/t/u tuples')
    find_solutions(tA, tB, stu)
    
def main():
    limit = 100_000
    solve(limit)

if __name__ == '__main__':
    main()

输出:

filter 0 ratio 0.0224
filter 1 ratio 0.0199
Table A size 100001, B size 371720
Read 27060 s/t/u tuples
STU solutions written to stu_solutions.100000

对于 50K 限制,所有找到的几乎解决方案的示例(其中只有 Z 不是整数):

{"stu": [3528, 37128, 10175], "res": [{"w": 31654, "x": 31850, "y": 48790, "z2": 2483994725, "z2_is_square": false, "z": 49839.69025786577}]}
{"stu": [7700, 12155, 5460], "res": [{"w": 10608, "x": 13108, "y": 16133, "z2": 290085289, "z2_is_square": false, "z": 17031.89035309939}]}
{"stu": [9405, 12155, 10608], "res": [{"w": 5460, "x": 10875, "y": 13325, "z2": 290085289, "z2_is_square": false, "z": 17031.89035309939}]}
{"stu": [11760, 18564, 31977], "res": [{"w": 13475, "x": 17885, "y": 22939, "z2": 1548726250, "z2_is_square": false, "z": 39353.8594041296}]}
{"stu": [14364, 18564, 13475], "res": [{"w": 31977, "x": 35055, "y": 36975, "z2": 1548726250, "z2_is_square": false, "z": 39353.8594041296}]}
{"stu": [15400, 24310, 10920], "res": [{"w": 21216, "x": 26216, "y": 32266, "z2": 1160341156, "z2_is_square": false, "z": 34063.78070619878}]}
{"stu": [18480, 29172, 13104], "res": [{"w": 21175, "x": 28105, "y": 36047, "z2": 1471101025, "z2_is_square": false, "z": 38354.93481939449}]}
{"stu": [18810, 24310, 21216], "res": [{"w": 10920, "x": 21750, "y": 26650, "z2": 1160341156, "z2_is_square": false, "z": 34063.78070619878}]}
{"stu": [21840, 43500, 30800], "res": [{"w": 14651, "x": 26299, "y": 45901, "z2": 3055541801, "z2_is_square": false, "z": 55276.95542448046}]}
{"stu": [22572, 29172, 21175], "res": [{"w": 13104, "x": 26100, "y": 31980, "z2": 1471101025, "z2_is_square": false, "z": 38354.93481939449}]}
{"stu": [23100, 36465, 16380], "res": [{"w": 31824, "x": 39324, "y": 48399, "z2": 2610767601, "z2_is_square": false, "z": 51095.67105929816}]}
{"stu": [23520, 37128, 63954], "res": [{"w": 26950, "x": 35770, "y": 45878, "z2": 6194905000, "z2_is_square": false, "z": 78707.7188082592}]}
{"stu": [28215, 36465, 31824], "res": [{"w": 16380, "x": 32625, "y": 39975, "z2": 2610767601, "z2_is_square": false, "z": 51095.67105929816}]}
{"stu": [30800, 48620, 21840], "res": [{"w": 42432, "x": 52432, "y": 64532, "z2": 4641364624, "z2_is_square": false, "z": 68127.56141239755}]}
{"stu": [36960, 37128, 31654], "res": [{"w": 10175, "x": 38335, "y": 38497, "z2": 2483994725, "z2_is_square": false, "z": 49839.69025786577}]}
{"stu": [37620, 43500, 14651], "res": [{"w": 30800, "x": 48620, "y": 53300, "z2": 3055541801, "z2_is_square": false, "z": 55276.95542448046}]}
{"stu": [37620, 48620, 42432], "res": [{"w": 21840, "x": 43500, "y": 53300, "z2": 4641364624, "z2_is_square": false, "z": 68127.56141239755}]}
于 2022-01-25T17:38:05.610 回答