如果我们为每个允许的三字母字符串分配一个概率,例如 p(abc)、p(abd)、p(acd) 等,就可以列出一组方程:
eqn1: p(abc) + p(abd) + ... (all other strings containing "a") ... = p1
eqn2: p(abd) + p(acd) + ... (all other strings containing "d") ... = p4
未知数的个数多于方程的个数,因此存在许多解。无论用哪种方法求出一组解(换作是我,会使用单纯形算法),之后都可以用 @alestanis 描述的轮盘赌方法,按每个字符串的概率进行采样。
from numpy import *
# using cvxopt-1.1.5
from cvxopt import matrix, solvers
# Functions to do some parts
# function to find all valid outputs
def perms(alphabet, length):
    """Yield every ordered selection of `length` characters from `alphabet`.

    Characters are picked by position without replacement, so each position
    of `alphabet` is used at most once per result. If `alphabet` contains a
    repeated character, the corresponding results are repeated as well.

    Args:
        alphabet: string of candidate characters.
        length: number of characters in each generated string.

    Yields:
        Strings of exactly `length` characters, ordered by the positions of
        the chosen characters. Nothing is yielded when `length` is negative
        or larger than `len(alphabet)`; a single empty string is yielded
        when `length` is 0.
    """
    if length < 0:
        # No string has a negative length; stop before recursing pointlessly.
        return
    if length == 0:
        yield ""
        # Stop here: without this return the loop below would keep recursing
        # with negative lengths over every remaining character, doing
        # factorial work that can never yield anything.
        return
    for i, val1 in enumerate(alphabet):
        # Remove the chosen character so it cannot be picked again.
        remaining = alphabet[:i] + alphabet[i + 1:]
        for val2 in perms(remaining, length - 1):
            yield val1 + val2
# roulette sampler
def roulette_sampler(values, probs):
    """Build a zero-argument function that samples from `values`.

    Each call of the returned function draws one uniform random number in
    [0, 1) from numpy's global random state and returns the first value
    whose cumulative probability exceeds it (roulette-wheel selection).

    Args:
        values: sequence of items to sample from.
        probs: probability of each item, in the same order as `values`.
            May be a flat sequence or an (n, 1) column; it is flattened
            before use.

    Returns:
        A function taking no arguments that returns one element of
        `values` per call.
    """
    # Local import so the function does not depend on which numpy names the
    # surrounding module happened to star-import.
    import numpy as np

    # Create cumulative prob distro in a single pass, sized from `probs`
    # itself rather than from a module-level global.
    probs_cum = np.cumsum(np.asarray(probs, dtype=float).ravel())

    def fun():
        r = np.random.rand()
        for p, s in zip(probs_cum, values):
            if r < p:
                return s
        # in case of rounding error (cumulative total slightly below r)
        return values[-1]

    return fun
# Main Part
# create list of all valid strings
alphabet = "abcd"
string_length = 3
# Target probability that each letter appears in the sampled string.
# The values sum to string_length, matching the number of letters per string.
alpha_probs = [string_length*x/30. for x in range(9,5,-1)]
# show probs
# (single-argument print(...) gives the same text on Python 2 and Python 3)
for a, p in zip(alphabet, alpha_probs):
    print("p(%s) = %s" % (a, p))
# all valid outputs for this particular case
strings = list(perms(alphabet, string_length))
n_strings = len(strings)
# constraints from probabilities p(abc) + p(abd) ... = p(a)
# contains[j, k] is 1 when string j includes letter k, else 0
contains = array([[1. if a in s else 0. for a in alphabet] for s in strings])
# Equality constraints A p = b: one row per letter.  Because every string
# holds exactly string_length distinct letters and alpha_probs sums to
# string_length, these rows already force the string probabilities to sum
# to 1, so no separate normalisation row is added.
A = matrix(contains.T)
b = matrix(alpha_probs)
# also need to constrain to [0,1]
# Inequalities G p <= h: each p <= 1, sum(p) <= 1, each -p <= 0, -sum(p) <= 0
wons = ones((1, n_strings))
G = matrix(concatenate((eye(n_strings), wons, -eye(n_strings), -wons)))
h = matrix(concatenate((ones(n_strings+1), zeros(n_strings+1))))
## target matrices for approx KL divergence
# uniform prob over valid outputs
u = 1./n_strings
P = matrix(eye(n_strings))
q = -0.5*u*matrix(ones(n_strings))
# solvers.qp minimises (1/2) p'Pp + q'p, i.e. 0.5*(p^2 - u*p) summed over all
# strings, so every string probability is penalised equally
# Do convex optimisation
sol = solvers.qp(P, q, G, h, A, b)
if sol['status'] != 'optimal':
    # keep going with whatever the solver returned, but make it visible
    print("warning: solver status is %s" % sol['status'])
# sol['x'] is an (n_strings, 1) column; flatten it so each entry is a scalar
probs = array(sol['x']).flatten()
# Print output
for s, p in zip(strings, probs):
    print("p(%s) = %s" % (s, p))
# Re-derive each letter's probability from the solution as a sanity check
checkprobs = [0. for char in alphabet]
for i, a in enumerate(alphabet):
    for s, p in zip(strings, probs):
        if a in s:
            checkprobs[i] += p
    print("p(%s) = %s" % (a, checkprobs[i]))
print("total = %s" % sum(probs))
# Create the sampling function
rndstring = roulette_sampler(strings, probs)
# Verify
print("sampling...")
test_n = 1000
output = [rndstring() for i in range(test_n)]
# count, for each letter, how many sampled strings contain it
# (len of a list rather than sum(): `sum` is numpy's here via the star import)
sampled_freqs = [len([val for val in output if char in val])
                 for char in alphabet]
print("plotting histogram...")
import matplotlib.pyplot as plt
plt.bar(range(0,len(alphabet)),array(sampled_freqs)/float(test_n), width=0.5)