3

我正在为 Hadoop 进行 R 中的分布式线性回归计算,但在实现它之前,我想验证我的计算是否与lm函数的结果一致。

我有以下函数试图实现 Andrew Ng 等人讨论的通用“求和”框架。在论文Map-Reduce for Machine Learning on Multicore中。

对于线性回归,这涉及将每行 y_i 和 x_i 映射到 P_i 和 Q_i,使得:

P_i = x_i * transpose(x_i)
Q_i = x_i * y_i

然后减少以求解系数,theta: theta = (sum(P_i))^-1 * sum(Q_i)

执行此操作的 R 函数是:

calculate_p <- function(dat_row) {
  dat_row %*% t(dat_row)
}

calculate_q <- function(dat_row) {
  dat_row[1,1] * dat_row[, -1]
}

calculate_pq <- function(dat_row) {
  c(calculate_p(matrix(dat_row[-1], nrow=1)), calculate_q(matrix(dat_row, nrow=1)))
}

map_pq <- function(dat) {
  t(apply(dat, 1, calculate_pq))
}

reduce_pq <- function(pq) {
  (1 / sum(pq[, 1])) * apply(pq[, -1], 2, sum)
}

您可以通过运行在一些合成数据上实现它:

X <- matrix(rnorm(20*5), ncol = 5)
y <- as.matrix(rnorm(20))
reduce_pq(map_pq(cbind(y, X)))
[1]  0.010755882 -0.006339951 -0.034797768  0.067438662 -0.033557351
coef(lm.fit(X, y))
          x1           x2           x3           x4           x5 
-0.038556283 -0.002963991 -0.195897701  0.422552974 -0.029823962

不幸的是,输出不匹配,所以很明显我做错了什么。有什么想法可以解决吗?

4

1 回答 1

3

您接受的逆reduce_pq需要是矩阵逆。我也稍微改变了一些功能。

calculate_p <- function(dat_row) { 
    dat_row %*% t(dat_row)
}

calculate_q <- function(dat_row) { 
    dat_row[1] * dat_row[-1] 
}

calculate_pq <- function(dat_row) {
    c(calculate_p(dat_row[-1]), calculate_q(dat_row)) 
}

map_pq <- function(dat) {
    t(apply(dat, 1, calculate_pq))
}

reduce_pq <- function(pq) { 
    solve(matrix(apply(pq[, 1:(ncol(X) * ncol(X))], 2, sum), nrow=ncol(X))) %*% apply(pq[, 1:ncol(X) + ncol(X)*ncol(X)], 2, sum)
}


set.seed(1)
X <- matrix(rnorm(20*5), ncol = 5)
y <- as.matrix(rnorm(20))

t(reduce_pq(map_pq(cbind(y, X))))
          [,1]      [,2]      [,3]       [,4]        [,5]
[1,] 0.1236914 0.2482445 0.5120975 -0.1104451 -0.04080922

coef(lm.fit(X,y))
         x1          x2          x3          x4          x5 
 0.12369137  0.24824449  0.51209753 -0.11044507 -0.04080922 

> all.equal(as.numeric(t(reduce_pq(map_pq(cbind(y, X))))), as.numeric(coef(lm.fit(X,y))))
[1] TRUE
于 2012-10-11T04:20:53.483 回答