0

我在 GitHub 上找到了这段代码,并试图理解函数的行为。我试图将这段代码与公式进行比较(来自本文的第 6 页):

在此处输入图像描述

我找不到这些公式在代码中实现的位置。谁能帮我解释一下公式和代码之间的相似之处?

class FCM() :
    def __init__(self, n_clusters=17, max_iter=100, m=2, error=1e-6):
        super().__init__()
        self.u, self.centers = None, None
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.m = m
        self.error = error

def fit(self, X):
    N = X.shape[0]
    C = self.n_clusters
    centers = []

    u = np.random.dirichlet(np.ones(C), size=N)

    iteration = 0
    while iteration < self.max_iter:
        u2 = u.copy()

        centers = self.next_centers(X, u)
        u = self.next_u(X, centers)
        iteration += 1

        # Stopping rule
        if norm(u - u2) < self.error:
            break

    self.u = u
    self.centers = centers
    return centers

def next_centers(self, X, u):
    um = u ** self.m
    return (X.T @ um / np.sum(um, axis=0)).transpose()  #Vi

def next_u(self, X, centers):
    return self._predict(X, centers)

def _predict(self, X, centers):
    power = float(2 / (self.m - 1))
    temp = cdist(X, centers) ** power
    denominator_ = temp.reshape((X.shape[0], 1, -1)).repeat(temp.shape[-1], axis=1)
    denominator_ = temp[:, :, np.newaxis] / denominator_

    return 1 / denominator_.sum(2)

def predict(self, X):
    if len(X.shape) == 1:
        X = np.expand_dims(X, axis=0)

    u = self._predict(X, self.centers)
    return np.argmax(u, axis=-1)

img2 = ret.reshape(x * y, z)
    algorithm = FCM()
    cluster_centers = algorithm.fit(img2)
    output = algorithm.predict(img2)
    img = cluster_centers[output].astype(np.int16).reshape(x, y, 3)
4

1 回答 1

0

我认为最相关的部分如下:

  • 该函数fit描述了算法的一般步骤:
    • 一些初始化
    • 中心和标签的更新(委托给其他功能)
    • 每次迭代后对停止标准的验证 ( norm(u - u2) < self.error)。
  • 等式(4)中的模糊隶属度在函数内部实现_predict
    • 指数存储在power
    • 差异规范由cdistscipy.spatial.distance 处理。对该函数的一次调用用于计算数据点 x j和聚类中心 vi 的所有组合之间的距离。每个结果都被提升到power并且结果存储在一个temp数组中。
    • _predict围绕数组条目的最后 3 行temp以这样一种方式进行处理:在没有循环的情况下,它计算每个“距离”和“距离”的适当总和(实际上是“距离的” power)之间的除法,但是一开始更容易忽略指数)。为此,代码使用了 numpy 中可用的一些技巧,例如reshaperepeatindexing
  • 等式(5)中的模糊中心由下式计算next_centers
    • um预先计算一个包含每个 u ijm的-th 幂的矩阵(供以后在分子和分母中使用)
    • 用于np.sum(um, axis=0)创建一个数组,其第-个条目是计算第-个簇i时在分母中使用的总和i
    • 对于第i-th 簇,(5)的分子将在i第 -th 列中计算(一旦我们得到结果矩阵,它就X.T @ um变成一行).transpose()
于 2021-01-05T19:21:16.473 回答