4

我正在为大学的一个实验室在 Scala 中实施符号分析。为此,我需要对抽象值进行算术运算,例如Pos, Neg, Zero, NonPos, NonNeg... 所以我必须在这些抽象值上声明方法+, -, *,/等。我实际上不需要成对定义所有内容,但可以定义 、 和 上的操作“核心” PosNeg然后Zero使用上限来定义,例如:

NonPos + Pos = leastUpperBound(Zero + Pos, Neg + Pos)

例如在哪里leastUpperBound(Zero, Neg) = NonPos

在 Scala 中,我使用 case 对象来表示值,并 在每个值leastUpperBound()上都有一个方法。但是我仍然有一些我无法摆脱的代码重复,例如我定义:

case object NonNeg extends Sign {
    def +(other: Sign) = leastUpperBound(Zero + other, Pos + other)
    def -(other: Sign) = ...
    def * = ...
    ...
}

我必须为:

case object NonPos extends Sign {
    def +(other: Sign) = leastUpperBound(Zero + other, Neg + other)
    ...
}

然后再次:

case object NonZero extends Sign {
    def +(other: Sign) = leastUpperBound(Neg + other, Neg + other)
    ...
}

我想知道是否有可能拥有某种“类型工厂”,以便我可以本着以下精神说些什么:

case object NonNeg extends UpperBoundSign[Pos, Zero]

我的直觉是s是不可能的Pos,但我对 Scala 不太熟悉,所以我可能会忘记一些允许我这样做的特性或模式。Zeroobject

有没有人知道删除这个重复?2.10 中的 Scala 宏是否可以很好地解决这个问题?

我希望问题很清楚,谢谢。

编辑:感谢@cmbaxter 的回答和我的一些重构,我想出了一个我喜欢的解决方案。如果有人有兴趣看到它,可以在这里找到:https ://gist.github.com/Ricordel/5553405 。

4

3 回答 3

5

我认为您可能会混淆类型标识符和类的实例。我相信为了获得你想要的功能,你需要定义UpperBoundSign一个抽象类,它带有两个构造函数参数,而不是具有两个类型标识符槽的泛型类型。这是一个过于简单的解决方案选项,可以满足您的需求。如果这完全不是您想要的,我深表歉意:

trait Sign{
  def +(other: Sign):Sign
}

abstract class UpperBoundSign(pos:Sign, neg:Sign) extends Sign{
  def leastUpperBound(pos:Sign, neg:Sign):Sign
  def +(other: Sign) = leastUpperBound(pos + other,  neg + other)
}

case object Pos extends Sign{
  def +(other:Sign) = ...
}

case object Neg extends Sign{
  def +(other:Sign) = ...
}

case object NonNeg extends UpperBoundSign(Pos, Neg){
  def leastUpperBound(pos:Sign, neg:Sign) = ...
}
于 2013-05-09T12:04:32.373 回答
1

好的,对不起,我误解了你的问题无论如何,我尝试了你的代码,它看起来很有趣,但是因为我喜欢简单,所以我试图为你的问题想出一个更简单的解决方案。

object Sign {
  case object Pos extends Sign
  case object Neg extends Sign
  case object Zero extends Sign
  case object Undefined extends Sign
  case object NonPos extends SignSet(Set(Neg, Zero)) {
    override def toString = "NonPos"
  }
  case object NonNeg extends SignSet(Set(Pos, Zero)) {
    override def toString = "NonNeg"
  }
  case object NonZero extends SignSet(Set(Pos, Neg)) {
    override def toString = "NonZero"
  }
  case object AnySign extends SignSet(Set(Pos, Neg, Zero)) {
    override def toString = "AnySign"
  }

  private val signs = List(Pos, Neg, Zero, Undefined, NonPos, NonNeg, NonZero, AnySign)

  private def calc(op: Symbol, s1: Sign, s2: Sign): Sign = {
    val sign = _calc(op, s1, s2)
    signs.find(_ == sign).getOrElse(sign)
  }

  private def _calc(op: Symbol, s1: Sign, s2: Sign): Sign = (op, s1, s2) match {
    case (op, set: SignSet, sign) => set.flatMap(s => _calc(op, s, sign))
    case (op, sign, set: SignSet) => set.flatMap(s => _calc(op, sign, s))
    case (_, Undefined, _) => Undefined
    case (_, _, Undefined) => Undefined

    case ('+, x, y) if x == y => x
    case ('+, x, Zero) => x
    case ('+, Zero, x) => x
    case ('+, Pos, Neg) => SignSet(Pos, Neg, Zero)
    case ('+, Neg, Pos) => SignSet(Pos, Neg, Zero)

    case ('-, x, Neg) => _calc('+, x, Pos)
    case ('-, x, Pos) => _calc('+, x, Neg)
    case ('-, x, Zero) => x

    case ('*, Zero, _) => Zero
    case ('*, Pos, x) => x
    case ('*, Neg, Pos) => Neg
    case ('*, Neg, Neg) => Pos
    case ('*, Neg, Zero) => Zero

    case ('/, _, Zero) => Undefined
    case ('/, x, y) => _calc('*, x, y)
  }
}

sealed trait Sign {
  import Sign.calc
  def +(other: Sign) = calc('+, this, other)
  def -(other: Sign) = calc('-, this, other)
  def *(other: Sign) = calc('*, this, other)
  def /(other: Sign) = calc('/, this, other)
  def flatten: Sign = this
  def |(other: Sign): Sign = other match {
    case sign if sign == this => this
    case SignSet(signs) => SignSet(signs + this)
    case sign => SignSet(this, sign)
  }
}

object SignSet {
  def apply(signs: Set[Sign]) = new SignSet(signs)
  def apply(signs: Sign*) = new SignSet(signs.toSet)
  def unapply(set: SignSet) = Some(set.signs)
}
class SignSet(val signs: Set[Sign]) extends Sign {
  def flatMap(f: Sign => Sign) = SignSet(signs.map(f)).flatten
  override def flatten = signs.map(_.flatten).reduce(_ | _)
  override def |(other: Sign) = other match {
    case SignSet(otherSigns) => SignSet(otherSigns | signs)
    case sign => SignSet(signs + sign)
  }
  override def toString = signs.mkString("SignSet(", ", ", ")")
  def equals(other: SignSet) = signs == other.signs
  override def equals(other: Any) = other match {
    case set: SignSet => equals(set)
    case _ => false
  }
}

import Sign._

println(Pos / NonPos)
println(Pos + Neg)
println(NonZero * Zero)
println(NonZero / NonPos)
println(NonZero - NonZero)
println(NonZero + Zero)
于 2013-05-10T18:01:51.020 回答
0

你可以用集合论试试。

abstract class Sign(name: String) {
  def contains(sign: Sign) = sign eq this
  def +(other: Sign) = findSign(Union(this, other))
  def -(other: Sign) = findSign(Difference(this, other))

  def equals(other: Sign) = (contains(Pos), contains(Neg), contains(Zero)) == (other.contains(Pos), other.contains(Neg), other.contains(Zero))
  override def equals(other: Any) = {
    if(other.isInstanceOf[Sign]) equals(other.asInstanceOf[Sign])
    else false
  }
  override def toString = name
}

case class Union(sign1: Sign, sign2: Sign, name: String = "Union") extends Sign(name) {
  override def contains(sign: Sign) = sign1.contains(sign) || sign2.contains(sign)
}
case class Intersection(sign1: Sign, sign2: Sign, name: String = "Intersection") extends Sign(name) {
  override def contains(sign: Sign) = (sign1.contains(sign) || sign2.contains(sign)) && !(sign1.contains(sign) && sign2.contains(sign))
}
case class Difference(sign1: Sign, sign2: Sign, name: String = "Difference") extends Sign(name) {
  override def contains(sign: Sign) = sign1.contains(sign) && !sign2.contains(sign)
}
case class Negation(sign: Sign, name: String = "Negation") extends Sign(name) {
  override def contains(s: Sign) = !sign.contains(s)
}

case object Zero extends Sign("Zero")
case object Pos extends Sign("Pos")
case object Neg extends Sign("Neg")
val NonPos = Negation(Pos, "NonPos")
val NonNeg = Negation(Neg, "NonNeg")
val NonZero = Negation(Zero, "NonZero")
val AnySign = Union(NonZero, Zero, "AnySign")
val NoSign = Negation(AnySign, "NoSign")

val signs = List(Zero, Pos, Neg, NonPos, NonNeg, NonZero, AnySign, NoSign)
def findSign(sign: Sign) = signs.find(_ == sign).get

println(Pos + Neg)
println(NonNeg - Zero)
println(NonZero + Zero)
println(Pos + Neg + Zero)
println(NonPos - Neg - Zero)
于 2013-05-09T13:45:31.087 回答