2

我目前正在使用 Python 工作,但遇到了一个我不知道该去哪里找救命稻草的问题。如果这在某个地方的一些初始算法 CS 课程中有所涉及,请原谅我,我的背景实际上是经济学。我正在处理财务数据,我知道输出和输入,我只是不知道如何达到操作顺序。

例如,我的最终市盈率为 2,但输入为 10(价格)和 5(收益)。看看这个,我知道 10/5 等于 2。但是,问题在于运算的顺序……这可能是加法、乘法、除法和平方根。

如果我刚刚拥有,这部分似乎是可行的

inputs = [10,5]
output = 2

def deduction_int(inputs, output):
    initial_output = 0
    while initial_output != output:
    try adding, try subtracting (inverse), try dividing(inverse)

当它自己弄清楚或有答案时打印'yay'

上面的代码看起来很明显而且很快,但是,当你向它添加 3 个变量时......

输入:10、5、7 输出:2.14

以及 (10 + 5) / 7 = 2.14 等情况。

我遇到了数字可能以不同顺序运行的情况。例如,在除以 7 之前运行 10+5。这是常见的算法类型问题吗?如果是这样,我究竟在哪里寻找一些教科书描述(算法名称,教科书)?

谢谢!

4

2 回答 2

2

这是一个蛮力算法。

from __future__ import division
import itertools as IT
import operator

opmap = {operator.add: '+',
         operator.mul: '*',
         operator.truediv: '/'}
operators = opmap.keys()

def deduction_int(inputs, output):
    iternums = IT.permutations(inputs, len(inputs))
    iterops = IT.product(operators, repeat=len(inputs)-1)
    for nums, ops in IT.product(iternums, iterops):
        for result, rstr in combine(nums, ops):
            if near(result, output, atol=1e-3):
                return rstr

def combine(nums, ops, astr=''):
    a = nums[0]
    astr = astr if astr else str(a)
    try:
        op = ops[0]
    except IndexError:
        return [(a, astr)]
    # combine a op (...)
    result = []
    for partial_val, partial_str in combine(nums[1:], ops[1:]):
        r = op(a, partial_val)
        if len(nums[1:]) > 1:
            rstr = '{}{}({})'.format(astr, opmap[op], partial_str)
        else:
            rstr = '{}{}{}'.format(astr, opmap[op], partial_str)
        assert near(eval(rstr), r)
        result.append((r, rstr))
    # combine (a op ...)
    b = nums[1]
    astr = '({}{}{})'.format(astr,opmap[op], b)
    for partial_val, partial_str in combine((op(a, b),)+nums[2:], ops[1:],
                                            astr):
        assert near(eval(partial_str), partial_val)
        result.append((partial_val, partial_str))
    return result

def near(a, b, rtol=1e-5, atol=1e-8):
    return abs(a - b) < (atol + rtol * abs(b))

def report(inputs, output):
    rstr = deduction_int(inputs, output)
    return '{} = {}'.format(rstr, output)

print(report([10,5,7], (10+5)/7))
print(report([1,2,3,4], 3/7.))
print(report([1,2,3,4,5], (1+(2/3)*(4-5))))

产量

(10+5)/7 = 2.14285714286
(1+2)/(3+4) = 0.428571428571
(1+5)/((2+4)*3) = 0.333333333333

主要思想是简单地枚举输入值的所有排序以及运算符的所有排序。例如,

In [19]: list(IT.permutations([10,5,7], 3))
Out[19]: [(10, 5, 7), (10, 7, 5), (5, 10, 7), (5, 7, 10), (7, 10, 5), (7, 5, 10)]

然后将输入值的每个排序与运算符的每个排序配对:

In [38]: list(IT.product(iternums, iterops))
Out[38]: 
[((10, 5, 7), (<built-in function add>, <built-in function mul>)),
 ((10, 5, 7), (<built-in function add>, <built-in function truediv>)),
 ((10, 5, 7), (<built-in function mul>, <built-in function add>)),
 ((10, 5, 7), (<built-in function mul>, <built-in function truediv>)),
 ...

combine函数对 nums 和 ops 进行排序,并枚举所有可能的 nums 和 ops 分组:在 [65] 中:combine((10, 5, 7), (operator.add, operator.mul) )

Out[65]: [(45, '10+(5*7)'), (45, '10+((5*7))'), (105, '(10+5)*7'), (105, '((10+5)*7)')]

它返回一个元组列表。每个元组是一个 2 元组,由一个数值和rstr计算为该值的分组操作的字符串表示形式 组成。

因此,您只需遍历所有可能性并返回rstrwhich,当评估时,会产生一个接近output.

for nums, ops in IT.product(iternums, iterops):
    for result, rstr in combine(nums, ops):
        if near(result, output, atol=1e-3):
            return rstr

一些有用的参考:

于 2013-10-03T21:18:51.540 回答
1

所以你得到了一些输入和一个输出,你想找到产生它的表达式。

做到这一点的简单方法是通过蛮力,通过生成和测试各种表达式。我的程序通过从以数字开头的简单表达式构建大表达式来做到这一点。它一遍又一遍地添加新生成的表达式与它们之前的所有内容的组合。

它打印出从简单到复杂的解决方案,直到内存不足。

#!python3

import operator
import decimal
import sys

# Automatically take care of divisions by zero etc
decimal.setcontext(decimal.ExtendedContext)

class Expression(object):
    def __init__(self, left, right):
        self.left = left
        self.right = right

class Number(Expression):
    def __init__(self, value):
        self.value = decimal.Decimal(value)

    def evaluate(self):
        return self.value

    def __str__(self):
        return str(self.value)

class Addition(Expression):
    def evaluate(self):
        return self.left.evaluate() + self.right.evaluate()

    def __str__(self):
        return "({0} + {1})".format(self.left, self.right)

class Subtraction(Expression):
    def evaluate(self):
        return self.left.evaluate() - self.right.evaluate()

    def __str__(self):
        return "({0} - {1})".format(self.left, self.right)

class Multiplication(Expression):
    def evaluate(self):
        return self.left.evaluate() * self.right.evaluate()

    def __str__(self):
        return "({0} * {1})".format(self.left, self.right)

class Division(Expression):
    def evaluate(self):
        return self.left.evaluate() / self.right.evaluate()

    def __str__(self):
        return "({0} / {1})".format(self.left, self.right)

class Sqrt(Expression):
    def __init__(self, subexp):
        self.subexp = subexp

    def evaluate(self):
        return self.subexp.evaluate().sqrt()

    def __str__(self):
        return "sqrt({0})".format(self.subexp)

def bruteforce(inputs, output, wiggle):
    inputs = [Number(i) for i in inputs]
    output = decimal.Decimal(output)
    wiggle = decimal.Decimal(wiggle)

    expressions = inputs
    generated = inputs

    while True:
        newgenerated = []
        for g in generated:
            for e in expressions:
                newgenerated.extend([
                    Addition(g, e),
                    Subtraction(g, e),
                    Multiplication(g, e),
                    Division(g, e)
                ])
            for e in expressions[0:len(expressions) - len(generated)]:
                # Subtraction and division aren't commutative. This matters
                # when the relation is not symmetric. However it is symmetric
                # for the most recently generated elements, so we don't worry
                # about commutivity for those.
                newgenerated.extend([
                    Division(e, g),
                    Subtraction(e, g)
                ])
        newgenerated.append(Sqrt(g))

        for c in newgenerated:
            if abs(c.evaluate() - output) < decimal.Decimal(.01):
                print(c)
                sys.stdout.flush()

        expressions.extend(newgenerated)
        generated = newgenerated

bruteforce((10, 5, 7), 2.14, .005)

印刷

((10 + 5) / 7)
((10 - 7) * (5 / 7))
((10 - 7) / (7 / 5))
((10 / 7) + (5 / 7))
((5 + 10) / 7)
((5 / 7) * (10 - 7))
((5 / 7) + (10 / 7))
(sqrt(7) - (5 / 10))

这些都没有精确地评估为 2.14,但它们在 0.005 的“摆动”内是相同的。到小数点后 3 位,它们都是 2.143,除了 sqrt 是 2.146。

在生成这些之后,它当然会因 MemoryError 而崩溃。我什至不想知道这个的时间或空间复杂度:)

于 2013-10-04T01:38:20.307 回答