我已将其重写如下:
首先,我创建了一个名义(nominal)类型的类工厂:
class BaseNominalType:
    """Base class for a nominal (named-category) value backed by a number.

    Subclasses override ``name_values`` with a dict that maps every
    allowed name to the value used when two instances are subtracted.
    """

    name_values = {}  # <= subclass must override this

    def __init__(self, name):
        """Store ``name`` and look up its value in ``name_values``.

        Raises:
            KeyError: if ``name`` is not a key of ``name_values``.
        """
        self.name = name
        self.value = self.name_values[name]

    def __str__(self):
        """Return the category name."""
        return self.name

    def __repr__(self):
        """Return a debugging representation such as ``Level('high')``."""
        return "{}({!r})".format(type(self).__name__, self.name)

    def __sub__(self, other):
        """Return the difference between the two underlying values.

        Raises:
            TypeError: if ``other`` is not an instance of exactly the
                same class as ``self``.
        """
        # Raise explicitly instead of using ``assert``: assertions are
        # removed when Python runs with -O, which would let mismatched
        # nominal types be subtracted silently.
        if type(self) is not type(other):
            raise TypeError("Incompatible types, subtraction is undefined")
        return self.value - other.value
# class factory function
def make_nominal_type(name_values):
    """Create a ``BaseNominalType`` subclass for the given categories.

    Args:
        name_values: one of
            * a mapping of name -> value,
            * an iterable of ``(name, value)`` tuples/lists, or
            * an iterable of names, which are numbered 0, 1, 2, ... in
              iteration order.

    Returns:
        A new subclass of ``BaseNominalType`` whose ``name_values`` is
        the resulting dict.
    """
    # Decide the input shape explicitly instead of waiting for dict() to
    # raise ValueError: dict() silently splits two-character names such as
    # "lo" into a key/value pair, and raises TypeError (not ValueError) for
    # names that are not iterable at all.
    if hasattr(name_values, "keys"):
        # dict() itself uses the presence of ``keys`` to recognise mappings.
        nv = dict(name_values)
    else:
        items = list(name_values)
        is_pair_list = all(
            isinstance(item, (tuple, list)) and len(item) == 2 for item in items
        )
        if is_pair_list:
            nv = dict(items)
        else:
            nv = {item: i for i, item in enumerate(items)}

    # make custom type
    class MyNominalType(BaseNominalType):
        name_values = nv

    return MyNominalType
现在就可以定义你的名义类型:
# Plain lists of names: each name is numbered by its position in the list.
Forest = make_nominal_type(
    ["shrubs", "plantation", "forestry", "other"],
)
Level = make_nominal_type(
    ["low", "medium", "high"],
)
# Explicit name -> value mapping.
Bool = make_nominal_type(
    {"f": False, "t": True},
)
然后我创建了一个 MixedVector 类型工厂:
# base class
class BaseMixedVectorType:
    """Base class for a fixed-layout vector whose components differ in type.

    Subclasses override two class attributes:

    * ``types``: one callable per component; each is applied to the
      corresponding raw value to build the stored component.
    * ``distance_fn``: called as ``self.distance_fn(diffs)`` where
      ``diffs`` is the list of absolute component-wise differences.
      Because it is looked up through the instance, a plain function
      stored on the class receives the instance as its first argument.
    """

    types = []          # <= subclass must
    distance_fn = None  # <= override these

    def __init__(self, values):
        """Convert each raw value with the matching entry of ``types``.

        Raises:
            ValueError: if the number of values differs from the number
                of declared types.
        """
        values = list(values)
        # zip() would silently drop the surplus on either side, producing
        # a shorter vector and therefore a wrong distance later on.
        if len(values) != len(self.types):
            raise ValueError(
                "expected {} values, got {}".format(len(self.types), len(values))
            )
        self.values = [type_(value) for type_, value in zip(self.types, values)]

    def dist(self, other):
        """Return ``distance_fn`` applied to the absolute component differences."""
        diffs = [abs(s - o) for s, o in zip(self.values, other.values)]
        return self.distance_fn(diffs)
# class factory function
def make_mixed_vector_type(types, distance_fn):
    """Create a ``BaseMixedVectorType`` subclass with the given layout.

    Args:
        types: iterable of per-component callables; it is copied into a
            list, so later changes to the argument do not affect the class.
        distance_fn: function stored as the class's ``distance_fn``.

    Returns:
        A new subclass of ``BaseMixedVectorType``.
    """
    # Bind under different local names: inside the class body below,
    # ``types = types`` would look the right-hand side up in the class and
    # global scopes, not in this function's scope.
    tl = list(types)
    df = distance_fn

    class MyVectorType(BaseMixedVectorType):
        types = tl
        distance_fn = df

    return MyVectorType
然后创建你的数据类型:
# your mixed-vector type
DataItem = make_mixed_vector_type(
    [Forest, Forest, Level, Level, Level, Level, int, Level, int, int, Bool],
    None,  # placeholder: have to define an appropriate distance function!
)
……但是等等,我们还没有定义距离函数!我把这个类写成可以插入任何你喜欢的距离函数。距离函数接收两个参数:第一个参数是向量实例本身(下面用 `_` 表示,不使用),第二个参数是各分量差值的绝对值列表。例如:
def manhattan_dist(_, vector):
    """Return the sum of the entries of ``vector``.

    Args:
        _: unused first positional argument.
        vector: iterable of numeric component differences.
    """
    return sum(vector)
def euclidean_dist(_, vector):
    """Return the square root of the sum of squared entries of ``vector``.

    Args:
        _: unused first positional argument.
        vector: iterable of numeric component differences.
    """
    return sum(v * v for v in vector) ** 0.5
# the distance function per your description:
def fractional_match_distance(_, vector):
    """Return the fraction of entries in ``vector`` that are falsy (zero).

    A zero difference means the two components matched, so the result is
    1.0 when every component matches and 0.0 when none does.

    Args:
        _: unused first positional argument.
        vector: sized collection of component differences.

    Raises:
        ZeroDivisionError: if ``vector`` is empty.
    """
    # float() keeps the division fractional under Python 2 as well.
    return float(sum(not v for v in vector)) / len(vector)
这样我们就可以完成数据类型的创建:
# your mixed-vector type: eleven components, compared by the fraction of
# components that match exactly
DataItem = make_mixed_vector_type(
    types=[
        Forest, Forest,
        Level, Level, Level, Level,
        int,
        Level,
        int, int,
        Bool,
    ],
    distance_fn=fractional_match_distance,
)
并按如下方式进行测试:
def main():
    """Build four sample ``DataItem`` vectors and print three pairwise distances."""
    raw_data = [
        ('forestry', 'plantation', 'high', 'low', 'high', 'medium', 3, 'low', 297, 1, 't'),
        ('plantation', 'plantation', 'high', 'medium', 'low', 'low', 1, 'low', 298, 2, 't'),
        ('other', 'shrubs', 'medium', 'high', 'medium', 'high', 0, 'high', 299, 0, 't'),
        ('forestry', 'forestry', 'low', 'high', 'high', 'medium', 4, 'medium', 297, 4, 'f'),
    ]
    # Loop variable is ``row`` rather than ``d`` so it does not shadow the
    # fourth unpacked item.
    a, b, c, d = [DataItem(row) for row in raw_data]
    print("a to b, dist = {}".format(a.dist(b)))
    print("b to c, dist = {}".format(b.dist(c)))
    print("c to d, dist = {}".format(c.dist(d)))
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
运行结果如下:
a to b, dist = 0.363636363636
b to c, dist = 0.0909090909091
c to d, dist = 0.0909090909091