2

我正在使用 CVXR 来编写惩罚线性回归。我的全局损失由 4 个元素组成:两个不同的 SSE 损失 loss_u、两个不同数据集上的 loss_b、岭惩罚和特定距离 D。如果我使用“距离 ==“MM””,则代码有效。但是,'distance == "MMD"' 存在错误。我使用内核“kernlab::kmmd”中的外部 rcpp 函数。问题是“Xb %*% beta”是一个 MulExpression。我不知道是否应该将其转换为数字(但如何?)或者是否无法使用 rcpp 函数。

deb_reg <- function(Xu, Yu, Xb, Yb, beta, lambda = 0, theta = 0.5, alpha = 0, distance = "MM") {
      n <- nrow(Xu)
      m <- nrow(Xb)
      ridge <- lambda * sum(beta^2)
      loss_u <- sum((Yu - Xu %*% beta)^2) * ( theta/ n )
      loss_b <- sum((Yb - Xb %*% beta)^2) * ( (1-theta)/ m )

      if(distance == "MM"){
        D <- alpha * ( mean(Yu) - mean(Xb %*% beta) )^2

      } else if(distance == "MMD"){
        y <- as.numeric(Yu)
        # print(beta)
        x <- Xb %*% beta
        # D <- alpha * EasyMMD::MMD(y, x)
        MMD <- kernlab::kmmd(as.matrix(y), as.matrix(x))
        D <- alpha * sum(MMD@mmdstats)

      } else{
        D <- 0

      }
      obj <- loss_u + loss_b + ridge +  D
      return(obj)
}

p <- ncol(X_unbiased)
beta <- Variable(p)
obj <- deb_reg(Xu = X_unbiased, Yu = Y_unbiased, Xb = X_biased, Yb = Y_biased, beta, 
               lambda = 0.1, theta=0.5, alpha = 10, distance = "MMD")
prob <- Problem(Minimize(obj))
result <- solve(prob)
4

0 回答 0