6

假设我有一个函数接受一个参数

def fun(x: Int) = x

基于此,我想生成一个具有相同调用约定的新函数,但这会在委托给原始函数之前对其参数进行一些转换。为此,我可以

def wrap_fun(f: (Int) => Int) = (x: Int) => f(x * 2)
wrap_fun(fun)(2) // 4

一个人怎么可能做同样的事情,除了任何只有部分参数来共同应用转换的函数?

def fun1(x: Int, y: Int) = x
def fun2(x: Int, foo: Map[Int,Str], bar: Seq[Seq[Int]]) = x

wrap_fun(fun1)(2, 4) // 4
wrap_fun(fun2)(2, Map(), Seq()) // 4

wrap_fun使上述调用起作用的定义如何?

4

4 回答 4

6

这可以相当直接地使用shapeless 的工具来抽象函数的数量,

import shapeless._
import HList._
import Functions._

def wrap_fun[F, T <: HList, R](f : F)
  (implicit
    hl :   FnHListerAux[F, (Int :: T) => R],
    unhl : FnUnHListerAux[(Int :: T) => R, F]) =
      ((x : Int :: T) => f.hlisted(x.head*2 :: x.tail)).unhlisted

val f1 = wrap_fun(fun _)
val f2 = wrap_fun(fun1 _)
val f3 = wrap_fun(fun2 _)

示例 REPL 会话,

scala> f1(2)
res0: Int = 4

scala> f2(2, 4)
res1: Int = 4

scala> f3(2, Map(), Seq())
res2: Int = 4

请注意,您不能立即应用包装函数(如问题所示),而不是通过分配的 val (正如我在上面所做的那样),因为包装函数的显式参数列表将与wrap_fun. 我们可以得到的最接近问题形式的方法是明确命名apply方法,如下所示,

scala> wrap_fun(fun _).apply(2)
res3: Int = 4

scala> wrap_fun(fun1 _).apply(2, 4)
res4: Int = 4

scala> wrap_fun(fun2 _).apply(2, Map(), Seq())
res5: Int = 4

在这里,明确提及的apply语法将第一个应用程序(wrap_fun及其隐式参数列表)与第二个应用程序(转换后的函数及其显式参数列表)分开。

于 2012-04-19T13:30:11.910 回答
6

与 Scala 中的往常一样,还有另一种方法可以实现您想要做的事情。

这是基于第一个参数与composeof的柯里化的结果Function1

def fun1(x : Int)(y : Int) = x
def fun2(x : Int)(foo : Map[Int, String], bar : Seq[Seq[Int]]) = x

def modify(x : Int) = 2*x

REPL 显示的结果类型将是:

fun1: (x: Int)(y: Int)Int
fun2: (x: Int)(foo: Map[Int,String], bar: Seq[Seq[Int]])Int
modify: (x: Int)Int

而不是包装函数fun1and fun2compose从技术上讲,它们现在都是Function1对象。这使您可以进行如下调用:

(fun1 _ compose modify)(2)(5)
(fun2 _ compose modify)(2)(Map(), Seq())

两者都将返回 4。当然,语法不是很好,因为您必须添加_以区分fun1的应用程序和函数对象本身(在这种情况下您要在其上调用compose方法)。

因此,Luigi 关于一般情况下不可能的论点仍然有效,但如果你可以自由地对你的函数进行柯里化,你可以用这种很好的方式来做到这一点。

于 2012-04-19T09:01:44.710 回答
2

由于采用不同数量参数的函数是不同的、不相关的类型,因此一般不能这样做。trait Function1 [-T1, +R] extends AnyRef没有别的了。您将需要为每个 arity 使用单独的方法。

于 2012-04-18T19:39:22.660 回答
1

虽然我投票支持并同意 Luigi 的回答——因为,你知道......他是的;Scala没有直接的、内置的支持这样的东西——值得注意的是,你想要做的事情并非不可能;只是实现起来有点痛苦,而且通常情况下,您最好只根据所需的数量实现一个单独的方法。

也就是说,虽然......我们实际上可以做到这一点HList。如果您有兴趣尝试一下,自然需要获得一个HList实现。我建议使用 Miles Sabin 出色的无形项目及其对HLists. 无论如何,这是一个使用它的示例,它可以完成类似于您似乎正在寻找的东西:

import shapeless._

trait WrapperFunner[T] {
  type Inputs <: HList
  def wrapFun(inputs: Inputs) :  T
}

class WrapsOne extends WrapperFunner[Int] {
  type Inputs = Int :: HNil
  def wrapFun(inputs: Inputs) : Int = {
    inputs match {
      case num :: HNil => num * 2
    }
  }
}

class WrapsThree extends WrapperFunner[String] {
  type Inputs = Int :: Int :: String :: HNil
  def wrapFun(inputs: Inputs) : String = {
    inputs match {
      case firstNum :: secondNum :: str :: HNil => str + (firstNum - secondNum)
    }
  }
}

object MyApp extends App {

  val wo = new WrapsOne
  println(wo.wrapFun(1 :: HNil))
  println(wo.wrapFun(17 :: HNil))
  //println(wo.wrapFun(18 :: 13 :: HNil))  // Would give type error

  val wt = new WrapsThree
  println(wt.wrapFun(5 :: 1 :: "your result is: " :: HNil))
  val (first, second) = (60, 50)
  println(wt.wrapFun(first :: second :: "%s minus %s is: ".format(first, second) :: HNil))
  //println(wt.wrapFun(1 :: HNil))  // Would give type error

}

运行MyApp结果:

2
34
your result is: 4
60 minus 50 is: 10

或者,更接近您的特定情况:

import shapeless._

trait WrapperFunner[T] {
  type Inputs <: HList
  def wrapFun(inputs: Inputs) :  T
}

trait WrapperFunnerBase extends WrapperFunner[Int] {
  // Does not override `Inputs`
  def wrapFun(inputs: Inputs) : Int = {
    inputs match {
      case (num: Int) :: remainder => num
    }
  }
}

class IgnoresNothing extends WrapperFunnerBase {
  type Inputs = Int :: HNil
}

class IgnoresLastTwo extends WrapperFunnerBase {
  type Inputs = Int :: Int :: String :: HNil
}

object MyApp extends App {

  val in = new IgnoresNothing
  println(in.wrapFun(1 :: HNil))
  println(in.wrapFun(2 :: HNil))
  //println(in.wrapFun(3 :: 4 :: HNil))  // Would give type error

  val ilt = new IgnoresLastTwo
  println(ilt.wrapFun(60 :: 13 :: "stupid string" :: HNil))
  println(ilt.wrapFun(43 :: 7  :: "man, that string was stupid..." :: HNil))
  //println(ilt.wrapFun(1 :: HNil))  // Would give type error

}

结果是:

1
2
60
43
于 2012-04-18T20:16:17.970 回答