2

我想编写一个 Scala 宏，根据映射（Map）中的条目覆盖 case class 的字段值，并进行简单的类型检查：如果原始字段类型与覆盖值的类型兼容，则设置新值，否则保留原始值。

到目前为止，我有以下代码：

    import language.experimental.macros
    import scala.reflect.macros.Context

    /** Question version: a def macro that tries to override case-class fields from a Map.
      *
      * It expands `withOverrides(entity, overrides)` into `entity.copy(f = ..., ...)`, choosing for
      * each case accessor either the value from `overrides` (if its type matches) or the original.
      * As written, it produces an "abstract type pattern K is unchecked" warning.
      */
    object ProductUtils {

        // Macro entry point: `entity` is the case-class instance, `overrides` maps field names to candidate values.
        def withOverrides[T](entity: T, overrides: Map[String, Any]): T =
            macro withOverridesImpl[T]

        // Macro implementation: builds `entity.copy(<one named argument per case accessor of T>)`.
        def withOverridesImpl[T: c.WeakTypeTag](c: Context)
                                               (entity: c.Expr[T], overrides: c.Expr[Map[String, Any]]): c.Expr[T] = {
            import c.universe._

            // NOTE(review): this tree is reused in every accessor Select below and again as the `copy`
            // receiver, so the `entity` expression appears to be evaluated once per field plus once for `copy`.
            val originalEntityTree = reify(entity.splice).tree
            val originalEntityCopy = entity.actualType.member(newTermName("copy"))

            // (field name, `entity.field` expression, declared field type) for each case accessor of T.
            // NOTE(review): the accessor expression is labelled Expr[T] although its value has the field's type.
            val originalEntity =
                weakTypeOf[T].declarations.collect {
                    case m: MethodSymbol if m.isCaseAccessor =>
                        (m.name, c.Expr[T](Select(originalEntityTree, m.name)), m.returnType)
                }

            // One named argument `name = <override or original value>` per field.
            val values =
                originalEntity.map {
                    case (name, value, ctype) =>
                        AssignOrNamedArg(
                            Ident(name),
                            {
                                // `overrides` is already typed Map[String, Any]; the cast below looks redundant.
                                def reifyWithType[K: WeakTypeTag] = reify {
                                    overrides
                                        .splice
                                        .asInstanceOf[Map[String, Any]]
                                        .get(c.literal(name.decoded).splice) match {
                                            // K is an abstract type parameter here, so this type pattern is
                                            // unchecked (this is the source of the erasure warning).
                                            case Some(newValue : K) => newValue
                                            case _                  => value.splice
                                        }
                                }

                                // Instantiate K with the field's declared type.
                                reifyWithType(c.WeakTypeTag(ctype)).tree
                            }
                        )
                }.toList

            // Emit `entity.copy(<named args>)`; abort if `copy` did not resolve to a method symbol.
            originalEntityCopy match {
                case s: MethodSymbol =>
                    c.Expr[T](
                        Apply(Select(originalEntityTree, originalEntityCopy), values))
                case _ => c.abort(c.enclosingPosition, "No eligible copy method!")
            }

        }

    }

调用方式如下：

    import macros.ProductUtils

    case class Example(field1: String, field2: Int, filed3: String)

    /** Demo entry point for ProductUtils.withOverrides.
      *
      * "field1" gets a String override matching its declared type. "field2" gets a String where an
      * Int is declared, so with the type check in place its original value 0 should be kept.
      * NOTE(review): `filed3` looks like a typo for `field3`; left as-is because renaming a
      * case-class field changes its public accessor and named-argument name.
      */
    object MacrosTest {
        def main(args: Array[String]): Unit = {
            val overrides = Map("field1" -> "new value", "field2" -> "wrong type")
            // Expected output (case-class toString format): Example(new value,0,)
            println(ProductUtils.withOverrides(Example("", 0, ""), overrides))
        }
    }

正如你所看到的，我已经设法获取了原始字段的类型，现在想在 `reifyWithType` 中用它与覆盖值的运行时类型进行比较。

不幸的是，在当前的实现中，编译时会收到如下警告：

warning: abstract type pattern K is unchecked since it is eliminated by erasure case Some(newValue : K) => newValue

并且在 IntelliJ 中编译时编译器崩溃：

Exception in thread "main" java.lang.NullPointerException
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseAsInstanceOf$1(Erasure.scala:1032)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseNormalApply(Erasure.scala:1083)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseApply(Erasure.scala:1187)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preErase(Erasure.scala:1193)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1268)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1018)
    at scala.reflect.internal.Trees$class.itransform(Trees.scala:1217)
    at scala.reflect.internal.SymbolTable.itransform(SymbolTable.scala:13)
    at scala.reflect.internal.SymbolTable.itransform(SymbolTable.scala:13)
    at scala.reflect.api.Trees$Transformer.transform(Trees.scala:2897)
    at scala.tools.nsc.transform.TypingTransformers$TypingTransformer.transform(TypingTransformers.scala:48)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1280)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1018)

所以问题是：
* 是否可以在宏中将获取到的类型与值的运行时类型进行比较？
* 或者有没有更好的方法来完成这个任务？

4

1 个回答

0

最终我得到了以下解决方案：

import language.experimental.macros
import scala.reflect.macros.Context

/** Answer version: def macro that expands `withOverrides(entity, overrides)` into
  * `entity.copy(f = ..., ...)`.
  *
  * Each case field takes the value from `overrides` when it passes a runtime compatibility check,
  * and otherwise keeps its original value. Because of erasure, type arguments of generic override
  * values are not fully checked.
  */
object ProductUtils {

    // Macro entry point: `entity` is the case-class instance, `overrides` maps field names to candidate values.
    def withOverrides[T](entity: T, overrides: Map[String, Any]): T =
        macro withOverridesImpl[T]

    def withOverridesImpl[T: c.WeakTypeTag](c: Context)(entity: c.Expr[T], overrides: c.Expr[Map[String, Any]]): c.Expr[T] = {
        import c.universe._

        // NOTE(review): this tree is used once per field accessor and again as the `copy` receiver,
        // so the `entity` expression appears to be evaluated several times in the expansion.
        val originalEntityTree = reify(entity.splice).tree
        val originalEntityCopy = entity.actualType.member(newTermName("copy"))

        // (field name, `entity.field` expression, declared field type) for each case accessor of T.
        // The accessor receiver goes through resetAllAttrs; the `copy` receiver further down does not.
        val originalEntity =
            weakTypeOf[T].declarations.collect {
                case m: MethodSymbol if m.isCaseAccessor =>
                    (m.name, c.Expr[T](Select(c.resetAllAttrs(originalEntityTree), m.name)), m.returnType)
            }

        // One named argument `name = <override if compatible, else original value>` per field.
        val values =
            originalEntity.map {
                case (name, value, ctype) =>
                    AssignOrNamedArg(
                        Ident(name),
                        {

                            // Runtime Class of the field's declared type (used for primitive and isInstance checks).
                            val ruClass = c.reifyRuntimeClass(ctype)
                            // Field type reified into the runtime universe; `mtree` selects its `.tpe`.
                            val mtag    = c.reifyType(treeBuild.mkRuntimeUniverseRef, Select(treeBuild.mkRuntimeUniverseRef, newTermName("rootMirror")), ctype)
                            val mtree   = Select(mtag, newTermName("tpe"))

                            def reifyWithType[K: c.WeakTypeTag] = reify {

                                // Some(candidate) if it is judged compatible with the field type, else None.
                                def tryNewValue[A: scala.reflect.runtime.universe.TypeTag](candidate: Option[A]): Option[K] =
                                    if (candidate.isEmpty) {
                                        None
                                    } else {
                                        val cc =  c.Expr[Class[_]](ruClass).splice
                                        val candidateValue = candidate.get
                                        val candidateType  = scala.reflect.runtime.universe.typeOf[A]
                                        val expectedType   = c.Expr[scala.reflect.runtime.universe.Type](mtree).splice

                                        // Primitive fields: the override value is boxed, so match its wrapper
                                        // class and compare the field's Class with the primitive TYPE.
                                        val ok = (cc.isPrimitive, candidateValue) match {
                                            case (true, _: java.lang.Integer)   => cc == java.lang.Integer.TYPE
                                            case (true, _: java.lang.Long)      => cc == java.lang.Long.TYPE
                                            case (true, _: java.lang.Double)    => cc == java.lang.Double.TYPE
                                            case (true, _: java.lang.Character) => cc == java.lang.Character.TYPE
                                            case (true, _: java.lang.Float)     => cc == java.lang.Float.TYPE
                                            case (true, _: java.lang.Byte)      => cc == java.lang.Byte.TYPE
                                            case (true, _: java.lang.Short)     => cc == java.lang.Short.TYPE
                                            case (true, _: java.lang.Boolean)   => cc == java.lang.Boolean.TYPE
                                            case (true, _: Unit)                => cc == java.lang.Void.TYPE
                                            case  _                             =>
                                                // NOTE(review): throws ClassCastException if candidateType is not a TypeRef.
                                                val args = candidateType.asInstanceOf[scala.reflect.runtime.universe.TypeRefApi].args
                                                // Exact `=:=` comparison unless A is, or has a type argument equal to, Any;
                                                // in that case fall back to a runtime class check.
                                                // NOTE(review): `overrides` is statically Map[String, Any], so A appears to
                                                // always be inferred as Any at the call below. The `=:=` branch is then never
                                                // taken, and the check reduces to `cc.isInstance`, which cannot see erased
                                                // type arguments (e.g. List[Int] vs List[java.util.Date]).
                                                if (!args.contains(scala.reflect.runtime.universe.typeOf[Any])
                                                       && !(candidateType =:= scala.reflect.runtime.universe.typeOf[Any]))
                                                    candidateType =:= expectedType
                                                else cc.isInstance(candidateValue)
                                        }

                                        if (ok)
                                            Some(candidateValue.asInstanceOf[K])
                                        else None
                                }

                                // Look up the override by field name; keep the original value if absent or incompatible.
                                tryNewValue(overrides.splice.get(c.literal(name.decoded).splice)).getOrElse(value.splice)
                            }

                            // Instantiate K with the field's declared type.
                            reifyWithType(c.WeakTypeTag(ctype)).tree
                        }
                    )
            }.toList

        // Emit `entity.copy(<named args>)`; abort if `copy` did not resolve to a method symbol.
        originalEntityCopy match {
            case s: MethodSymbol =>
                c.Expr[T](
                    Apply(Select(originalEntityTree, originalEntityCopy), values))
            case _ => c.abort(c.enclosingPosition, "No eligible copy method!")
        }

    }

}

它满足了最初的需求：

/** ScalaTest suite for ProductUtils.withOverrides.
  *
  * Overrides whose runtime type fits the field replace it; mismatching overrides are ignored and
  * the original field value is kept.
  */
class ProductUtilsTest extends FunSuite {

    case class A(a: String, b: String)
    case class B(a: String, b: Int)
    case class C(a: List[Int], b: List[String])
    case class D(a: Map[Int, String], b: Double)
    case class E(a: A, b: B)

    test("simple overrides works"){
        val overrides = Map("a" -> "A", "b" -> "B")
        assert(ProductUtils.withOverrides(A("", ""), overrides) === A("A", "B"))
    }

    test("simple overrides works 1"){
        // Int override for an Int field exercises the primitive (boxed vs. primitive Class) path.
        val overrides = Map("a" -> "A", "b" -> 1)
        assert(ProductUtils.withOverrides(B("", 0), overrides) === B("A", 1))
    }

    test("do not override if types do not match"){
        val overrides = Map("a" -> 0, "b" -> List("B"))
        assert(ProductUtils.withOverrides(B("", 0), overrides) === B("", 0))
    }

    test("complex types also works"){
        val overrides = Map("a" -> List(1), "b" -> List("A"))
        assert(ProductUtils.withOverrides(C(List(0), List("")), overrides) === C(List(1), List("A")))
    }

    test("complex types also works 1"){
        // A List is not a Map, so `a` keeps its original value; only `b` is overridden.
        // Date is fully qualified so the suite does not depend on a java.util import.
        val overrides = Map("a" -> List(new java.util.Date()), "b" -> 2.0d)
        assert(ProductUtils.withOverrides(D(Map(), 1.0), overrides) === D(Map(), 2.0))
    }

    test("complex types also works 2"){
        val overrides = Map("a" -> A("AA", "BB"), "b" -> 2.0d)
        assert(ProductUtils.withOverrides(E(A("", ""), B("", 0)), overrides) === E(A("AA", "BB"), B("", 0)))
    }

}

不幸的是，由于 Java/Scala 中的类型擦除，在用新值替换原值之前很难严格检查（泛型参数的）类型是否一致，因此仍然可能出现下面这种情况：

scala> case class C(a: List[Int], b: List[String])
defined class C

scala> val overrides = Map("a" -> List(new Date()), "b" -> List(1.0))
overrides: scala.collection.immutable.Map[String,List[Any]] = Map(a -> List(Mon Aug 26 15:52:27 CEST 2013), b -> List(1.0))

scala> ProductUtils.withOverrides(C(List(0), List("")), overrides)
res0: C = C(List(Mon Aug 26 15:52:27 CEST 2013),List(1.0))

scala> res0.a.head + 1
java.lang.ClassCastException: java.util.Date cannot be cast to java.lang.Integer
    at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:106)
    at .<init>(<console>:14)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:606)
    at scala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:734)
    at scala.tools.nsc.interpreter.IMain$Request.loadAndRun(IMain.scala:983)
    at scala.tools.nsc.interpreter.IMain.loadAndRunReq$1(IMain.scala:573)
    at scala.tools.nsc.interpreter.IMain.interpret(IMain.scala:604)
回答于 2013-08-26T14:01:23.180