4

我有一个成功的 weka 分类器(在 scala 中)。我正在尝试实现第二个分类器与它一起工作。

在训练期间,我正在为第一个分类器设置实例,如下所示:

val data = new Instances("Actions",attrs,10)

这工作正常。然而,简单地添加以下行会破坏程序。

val tagdata = new Instances("Tags",tagattrs,10)

添加此行时,程序会在第一个分类器更新时崩溃 从字面上看,工作和非工作之间的唯一变化是对该行的添加。Tagdata 从未使用过(所有使用它的内容都已被注释掉,对于那些阅读我最后一个问题的人),并且程序甚至不会在添加的行上崩溃。这使得创建第二组实例似乎有一些奇怪的副作用。分类器不能并排训练吗?

这是它抛出的错误:

java.lang.ArrayIndexOutOfBoundsException: 75
at weka.estimators.DiscreteEstimator.addValue(DiscreteEstimator.java:89)
at weka.classifiers.bayes.NaiveBayes.updateClassifier(NaiveBayes.java:322)
at Parser$.addInstance$1(Parser.scala:253)
...

这是一个工作子样本(只要您在 ./lib/ 中有 weka.jar,在 ./resources/ 中有一个 conll 格式的测试文件)。取消注释该val tagdata...行将产生上述问题。

import scala.collection.mutable.Queue
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.MultiMap
import scala.collection.mutable.MutableList
import scala.collection.mutable.Stack
import scala.actors.Actor
import scala.actors.Actor._
import scala.io.Source
import java.io.File
import scala.util.Random
import weka.core._
import weka.classifiers.bayes.NaiveBayesUpdateable
import java.net.URL

object Parser{
  //TypeDefs
  type Arc = (Int, String) //Index, Label
  type Tag = (String, Double) //Tag, probability
  type Token = (Int, String, Int, String) //Index, Word, Head, Label
  //Running params
  val testing = true
  var training = true
  var parsing = true
  val testFile = "test.conll"
  val chooser = new wekaChooser()
  //Data structs
  val storedTokens = MutableList[Token]()
  val buffer = Queue[Token]()
  var stack = Stack[Token]()
  val arcs = HashMap[Int, Arc]()
  val tags = HashMap[Int, String]()
  val classifier = new NaiveBayesUpdateable()
  val tagclassifier = new NaiveBayesUpdateable()
  val tagFrequencies = HashMap[String, List[Tag]]()
  val supplementalTagFrequencies = Map("eat" -> List(("VB", 1.0)), "beans" ->     List(("NN", 1.0)))
  val contextLists = Map("unknown" -> List(), "confirmation" -> List(("UH", 0.8), "right",     "correct", "right", "yes"))
  val context = ListBuffer[Tag]()
  val attrs = new FastVector(4)
  val tagattrs = new FastVector(4)
  var counter = 1

  def main(args: Array[String]){
    if(training) train(testFile)
  }

  def resetLists()={
    buffer.clear
    stack.clear
    stack.push((0, "ROOT", 0, "ROOT"))
    arcs.clear  
    storedTokens.clear
  }

  def applyAction(action:String, buffer:Queue[Token], stack:Stack[Token]):Unit= {
    action match{    
    case "QUIT" => {
      println("Could not parse :(")
      buffer.clear
      stack.clear
    }
    case "SHIFT" => {
      stack.push(buffer.front)
      buffer.dequeue
    }
    case "REDUCE" => {
     // tags.put(stack.top._1, 
       storedTokens += stack.pop
    }
    case x =>{
      x.split("~")match{
        case Array(action, label) => 
          if(action=="LEFTARC"){
            arcs.put(stack.top._1, (buffer.front._1, label))
            storedTokens += stack.pop
          }
          if(action=="RIGHTARC"){
            arcs.put(buffer.front._1, (stack.top._1, label))
            applyAction("SHIFT", buffer, stack)
          }}}}
      }

def train(testFile:String)={
  var wordsmap = HashMap[Int, String]()

def leftDependents(tokenindex:Int) = {
  arcs.filter(_._2._1==tokenindex).filterKeys(_<tokenindex).keys.toList
}

def addInstance(action:String)={
 val data = new Instances("Actions",attrs,10)
//     val tagdata = new Instances("Tags",tagattrs,10)
 data.setClassIndex(2)
 //tagdata.setClassIndex(2)
 val instance = new Instance(3)
//     val taginstance = new Instance(3)
 instance.setDataset(data)     
 //taginstance.setDataset(tagdata)     

 val wordshead = wordsmap.size match{
   case 0 => "NIL"
   case _ => {
     wordsmap(buffer.front._1)
   }
 }

 List(stack.top._2, buffer.front._2, action).zipWithIndex.foreach{
   case (attrString:String, index)=> attrs.elementAt(index) match{
     case attr:Attribute => {
       instance.setValue(attr,attrString)
        }
     }         
   }

 println(instance)
  //   println(taginstance)

//     tagclassifier.updateClassifier(taginstance)
 classifier.updateClassifier(instance)
 }

def setupClassifier = {
 val acts = List("LEFTARC~", "RIGHTARC~")  
 val List(words:List[String], cpostags:List[String], deprels:List[String]) =   Source.fromURL(getClass.getResource("resources/"+testFile)).getLines.toList.map{
   line => line.split("\t") match{
     case Array(id, word, _, cpostag, _, _, head, deprel, _, _) => (word, cpostag, deprel)
     case _ => ("","","")
   }
 }.unzip3.productIterator.toList.map{
   case x:List[String] => (x.toSet++Set("ROOT", "NIL")).toList
   case _ => List()
 }

 val classWords = new FastVector(words.length)
 for(word <- words) classWords.addElement(word)
 val classTags = new FastVector(cpostags.length)
 for(tag <- cpostags) classTags.addElement(tag)

 val stacktag = new Attribute("stacktag", classTags)
 val buffertag = new Attribute("buffertag", classTags)
 val wordattr = new Attribute("word", classWords)

 val fvaction = new FastVector(deprels.length*2+2)
 fvaction.addElement("SHIFT")
 fvaction.addElement("REDUCE")
 for(act <- acts; deprel <- deprels) fvaction.addElement(act+deprel)     
 val action = new Attribute("action", fvaction)

 attrs.addElement(stacktag)
 attrs.addElement(buffertag)
 attrs.addElement(action)     

 tagattrs.addElement(stacktag)
 tagattrs.addElement(wordattr)
 tagattrs.addElement(buffertag)

 val structure =  new Instances("Actions", attrs, 10)
 structure.setClassIndex(2)
 classifier.buildClassifier(structure)

 val tagstructure =  new Instances("Tags", tagattrs, 10)
 tagstructure.setClassIndex(2)
 tagclassifier.buildClassifier(tagstructure)
 }    
//The Code
println("training")
//Based on Malt's oracle prediction training. Assisted in finding that by Rehj.
val wordtaglist = HashMap[String, ListBuffer[String]]()
resetLists
setupClassifier    
Source.fromURL(getClass.getResource("resources/"+testFile)).getLines.foreach{
 line =>{
   if(!line.isEmpty)
     line.split("\t") match{
       case Array(id, word, _, cpostag, _, _, head, deprel, _, _) =>{
         wordtaglist.getOrElseUpdate(word, ListBuffer[String]())+=cpostag //update tag counts
         buffer+=((id.toInt, cpostag, head.toInt, deprel))
         wordsmap+=(id.toInt -> word)
         while(buffer.nonEmpty){
           val action = {
             if(!(stack.top._1==0) && (stack.top._3 == buffer.front._1))
               "LEFTARC~"+stack.top._4
             else if(buffer.front._3 == stack.top._1)
               "RIGHTARC~"+buffer.front._4
             else if(leftDependents(buffer.front._1).nonEmpty && leftDependents(buffer.front._1)(0) < stack.top._1)
               "REDUCE"
             else if(buffer.front._3 < stack.top._1) //&& buffer.front._3 is not root or handling roots normally
               "REDUCE"
             else "SHIFT"
           }
           addInstance(action)
           applyAction(action, buffer, stack)
         }
       }
       case _ => println("line didn't match")
     }else {
       resetLists
       wordsmap.clear
     }
 }
}
resetLists
tagFrequencies++=wordtaglist.map{ case(k,v) => (k, v.groupBy(x=>x).map(x=>((x._1,  x._2.length*1.0/v.length))).toList)}
training = false
}
}
4

0 回答 0