我正在尝试修改优秀书籍《Programming Collective Intelligence》中朴素贝叶斯分类器的代码,使其改用 Google App Engine (GAE) 数据存储(书中的代码使用 pysqlite2)。但在修改过程中,我遇到了一个难以调试的错误,错误信息如下:
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 216, in post
sampletrain(nb)
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 201, in sampletrain
cl.train('Nobody owns the water.','good')
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 139, in train
self.incf(f,cat)
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 71, in incf
count=self.fcount(f,cat)
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 92, in fcount
return float(res)
TypeError: float() argument must be a string or a number
错误出在这段代码中:
def fcount(self,f,cat):
res = db.GqlQuery("SELECT * FROM fc WHERE feature =:feature AND category =:category", feature = f, category = cat).get()
# res=self.con.execute(
# 'select count from fc where feature="%s" and category="%s"'
# %(f,cat)).fetchone()
if res is None: return 0
else:
# BUG (review): `fc.count` reads the `count` attribute of the model CLASS `fc`
# (the db.IntegerProperty declared on the class), not the count stored in the
# entity just fetched into `res`. float() of that property object raises
# "TypeError: float() argument must be a string or a number". Use `res.count`.
res = fc.count
return float(res)
# return float(res[0])
如果我把 `set_trace()` 放在第 91 行,像这样:
def fcount(self,f,cat):
res = db.GqlQuery("SELECT * FROM fc WHERE feature =:feature AND category =:category", feature = f, category = cat).get()
# Debugger breakpoint: execution pauses here so `res` can be inspected.
set_trace()
# res=self.con.execute(
# 'select count from fc where feature="%s" and category="%s"'
# %(f,cat)).fetchone()
if res is None: return 0
else:
# BUG (review): `fc.count` reads the `count` attribute of the model CLASS `fc`
# (the db.IntegerProperty declared on the class), not the count stored in the
# fetched entity `res`; float() of it raises the TypeError. Use `res.count`.
# The BdbQuit in the resulting traceback is raised by bdb when the debugger
# session is quit (bdb.dispatch_line: `if self.quitting: raise BdbQuit`);
# it is not a separate bug in this code.
res = fc.count
set_trace()
return float(res)
则会得到如下错误回溯:
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 224, in post
sampletrain(nb)
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 209, in sampletrain
cl.train('Nobody owns the water.','good')
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 147, in train
self.incf(f,cat)
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 77, in incf
count=self.fcount(f,cat)
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 95, in fcount
if res is None: return 0
File "C:\Users\CG\Desktop\Google Drive\Sci&Tech\projects\naivebayes\main.py", line 95, in fcount
if res is None: return 0
File "C:\Python27\lib\bdb.py", line 48, in trace_dispatch
return self.dispatch_line(frame)
File "C:\Python27\lib\bdb.py", line 67, in dispatch_line
if self.quitting: raise BdbQuit
BdbQuit
问题似乎与 GqlQuery 有关。我想在 Python IDE 中逐步运行代码,打印变量和查询结果,以找出问题所在。但在 IDE 中运行时会报错(例如 "ImportError: No module named webapp2")。而且我对程序的执行流程不够熟悉,不知道该如何修改才能让它在 IDE 中运行——我确实尝试过,但最后迷失了方向(我是一名新手程序员,最近才开始学习 OOP)。在这种情况下,找出错误的最佳方法是什么?
我希望回答能指出这个错误的具体原因。
提前感谢您的帮助!
以下是完整代码:
#!/usr/bin/env python
#
# Copyright 2007 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
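# -*- coding: utf-8 -*-
# NOTE: Python only honours a coding declaration on the first or second line
# of a file; move the line above to the top for it to take effect.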
import os
import webapp2
import jinja2
from jinja2 import Environment, FileSystemLoader
# Shared Jinja2 environment: templates are loaded from the "templates"
# directory next to this file, and autoescape=True HTML-escapes rendered values.
jinja_environment = jinja2.Environment(
    loader=jinja2.FileSystemLoader(
        os.path.join(os.path.dirname(__file__), 'templates')),
    autoescape=True,
)
import random
from google.appengine.ext import db
import re
import math
def set_trace():
    """Start an interactive pdb session stopped in the caller's frame.

    The debugger is wired to ``sys.__stdin__`` / ``sys.__stdout__`` (the
    original process streams) rather than ``sys.stdin`` / ``sys.stdout``,
    presumably because the App Engine dev server redirects the latter --
    not shown in this file.
    """
    import pdb
    import sys
    # Frame one level up from here, i.e. the function that called set_trace().
    caller_frame = sys._getframe(1)
    pdb.Pdb(stdin=sys.__stdin__, stdout=sys.__stdout__).set_trace(caller_frame)
class fc(db.Model):
    """Datastore entity: how many times a feature was seen in a category.

    Stands in for the book's ``fc(feature, category, count)`` SQLite table.
    """
    feature = db.StringProperty(required=True)   # the feature (word)
    category = db.StringProperty(required=True)  # category name, e.g. 'good'
    count = db.IntegerProperty(required=True)    # occurrences of the pair
class cc(db.Model):
    """Datastore entity: how many items were trained into a category.

    Stands in for the book's ``cc(category, count)`` SQLite table.
    """
    category = db.StringProperty(required=True)  # category name
    count = db.IntegerProperty(required=True)    # items trained into it
def getfeatures(doc):
    """Return the distinct lowercase words of *doc* as classifier features.

    Words are the runs between non-word characters; only words of length
    3..19 are kept. The result maps each word to 1 (only the keys are used
    as features).
    """
    # Split on runs of non-word characters. ``\W+`` rather than ``\W*``: a
    # pattern that can match the empty string makes re.split() split between
    # every character on Python 3.7+, which would discard every word. On
    # Python 2.7 both patterns behave the same.
    words = (s.lower() for s in re.split(r'\W+', doc) if 2 < len(s) < 20)
    return {w: 1 for w in words}
class classifier:
    """Feature/category counting classifier backed by the App Engine datastore.

    Port of the book's SQLite classifier: the ``fc`` table becomes the ``fc``
    model (feature, category, count) and the ``cc`` table becomes the ``cc``
    model (category, count).

    NOTE(review): all lookups are non-ancestor GQL queries. On the High
    Replication Datastore these are eventually consistent, so an entity
    just put() may not be returned by the next query and incf/incc could
    then create duplicates. Keying entities by key_name and reading them with
    get_by_key_name() would avoid that -- consider.
    """

    def __init__(self, getfeatures, filename=None):
        """Create a classifier.

        Args:
          getfeatures: callable mapping an item to a dict whose keys are the
            item's features.
          filename: unused; kept for signature compatibility with the
            SQLite version.
        """
        # In-memory counters from the book's dictionary version; the
        # datastore-backed methods below do not use them.
        self.fc = {}
        self.cc = {}
        self.getfeatures = getfeatures

    def _fc_entity(self, f, cat):
        """Return the ``fc`` entity for (feature *f*, category *cat*), or None."""
        # SELECT * rather than a projection such as "SELECT count": projection
        # queries return partial entities, which cannot be put() back.
        return db.GqlQuery(
            "SELECT * FROM fc WHERE feature =:feature AND category =:category",
            feature=f, category=cat).get()

    def _cc_entity(self, cat):
        """Return the ``cc`` entity for category *cat*, or None."""
        return db.GqlQuery(
            "SELECT * FROM cc WHERE category =:category", category=cat).get()

    def incf(self, f, cat):
        """Increment the count of feature *f* in category *cat* (create at 1)."""
        entity = self._fc_entity(f, cat)
        if entity is None:
            fc(feature=f, category=cat, count=1).put()
        else:
            # Update the entity we already fetched instead of querying again.
            entity.count += 1
            entity.put()

    def fcount(self, f, cat):
        """Return how often feature *f* was seen in *cat* (float), or 0."""
        entity = self._fc_entity(f, cat)
        if entity is None:
            return 0
        # Read the count from the fetched entity. Reading ``fc.count`` instead
        # yields the model class's property object, not a number, and made
        # float() raise "argument must be a string or a number".
        return float(entity.count)

    def incc(self, cat):
        """Increment the item count of category *cat* (create at 1)."""
        entity = self._cc_entity(cat)
        if entity is None:
            cc(category=cat, count=1).put()
        else:
            entity.count += 1
            entity.put()

    def catcount(self, cat):
        """Return the number of items trained into *cat* (float), or 0."""
        entity = self._cc_entity(cat)
        if entity is None:
            return 0
        # float() must be applied to the stored count, not the entity itself.
        return float(entity.count)

    def categories(self):
        """Return the list of category names stored in ``cc``."""
        # Model instances are not indexable; read the attribute (not d[0]).
        # Iterating the query returns every result instead of capping at 999.
        return [entity.category for entity in db.GqlQuery("SELECT * FROM cc")]

    def totalcount(self):
        """Return the total number of trained items over all categories (0 if none)."""
        return sum(entity.count for entity in db.GqlQuery("SELECT * FROM cc"))

    def train(self, item, cat):
        """Count every feature of *item* under *cat*, then count *item* itself."""
        for f in self.getfeatures(item):
            self.incf(f, cat)
        self.incc(cat)

    def fprob(self, f, cat):
        """Return Pr(feature | category): fcount / catcount, or 0 for an empty category."""
        total = self.catcount(cat)
        if total == 0:
            return 0
        return self.fcount(f, cat) / total

    def weightedprob(self, f, cat, prf, weight=1.0, ap=0.5):
        """Return *prf(f, cat)* blended with the assumed probability *ap*.

        The assumed probability counts as *weight* observations; the real
        probability counts as the number of times *f* appeared in any category.
        """
        basicprob = prf(f, cat)
        totals = sum(self.fcount(f, c) for c in self.categories())
        return ((weight * ap) + (totals * basicprob)) / (weight + totals)
class naivebayes(classifier):
    """Naive Bayes classifier with optional per-category decision thresholds."""

    def __init__(self, getfeatures):
        classifier.__init__(self, getfeatures)
        # category -> threshold multiplier used by classify(); default 1.0
        self.thresholds = {}

    def docprob(self, item, cat):
        """Return Pr(item | cat): the product of the weighted feature probabilities."""
        p = 1
        for f in self.getfeatures(item):
            p *= self.weightedprob(f, cat, self.fprob)
        return p

    def prob(self, item, cat):
        """Return Pr(item | cat) * Pr(cat), or 0 if nothing has been trained."""
        total = self.totalcount()
        if total == 0:
            # Avoid ZeroDivisionError on an empty datastore.
            return 0
        catprob = self.catcount(cat) / total
        return self.docprob(item, cat) * catprob

    def setthreshold(self, cat, t):
        """Require *cat* to beat every other category by a factor of *t*."""
        self.thresholds[cat] = t

    def getthreshold(self, cat):
        """Return the threshold for *cat* (1.0 when none was set)."""
        return self.thresholds.get(cat, 1.0)

    def classify(self, item, default=None):
        """Return the most probable category for *item*, or *default*.

        *default* is returned when no category has a positive probability
        (including when nothing has been trained), or when some other
        category's probability times the best category's threshold exceeds
        the best probability.
        """
        probs = {}
        best = None
        best_prob = 0.0  # renamed from ``max`` to stop shadowing the builtin
        for cat in self.categories():
            probs[cat] = self.prob(item, cat)
            if probs[cat] > best_prob:
                best_prob = probs[cat]
                best = cat
        if best is None:
            # ``best`` used to be assigned only inside the loop, which raised
            # UnboundLocalError when no category scored above 0.
            return default
        threshold = self.getthreshold(best)
        for cat, p in probs.items():
            if cat != best and p * threshold > probs[best]:
                return default
        return best
def sampletrain(cl):
    """Train *cl* on five fixed sample sentences labelled 'good' or 'bad'."""
    samples = (
        ('Nobody owns the water.', 'good'),
        ('the quick rabbit jumps fences', 'good'),
        ('buy pharmaceuticals now', 'bad'),
        ('make quick money at the online casino', 'bad'),
        ('the quick brown fox jumps', 'good'),
    )
    for text, label in samples:
        cl.train(text, label)
class MainHandler(webapp2.RequestHandler):
    """Serve the input form (GET) and classify a submitted sentence (POST)."""

    def get(self):
        """Render index.html with a placeholder sentence."""
        template_values = {"given_sentence": 'put a name here'}
        template = jinja_environment.get_template('index.html')
        self.response.out.write(template.render(template_values))

    def post(self):
        """Classify the posted ``given_sentence`` and redirect to /test_result.

        The redirect happens only when ``submit_button`` is present in the
        request; otherwise nothing is written to the response.
        """
        try:
            from urllib import urlencode  # Python 2 (App Engine python27)
        except ImportError:
            from urllib.parse import urlencode  # Python 3
        nb = naivebayes(getfeatures)
        # Training counts are stored in the datastore and persist between
        # requests, so retraining on every POST kept inflating them. Seed the
        # sample data only while nothing has been trained yet.
        # NOTE(review): totalcount() is an eventually consistent query, so a
        # request arriving right after the first training may still see 0.
        if nb.totalcount() == 0:
            sampletrain(nb)
        given_sentence = self.request.get("given_sentence")
        spam_result = nb.classify(given_sentence)
        submit_button = self.request.get("submit_button")
        if submit_button:
            # URL-encode the parameters: a raw sentence containing spaces,
            # '&', '#' or '=' would otherwise corrupt the query string.
            query = urlencode([
                ('spam_result', (u'%s' % spam_result).encode('utf-8')),
                ('given_sentence', given_sentence.encode('utf-8')),
            ])
            self.redirect('/test_result?' + query)
class test_resultHandler(webapp2.RequestHandler):
    """Render test_result.html from the ``spam_result`` and ``given_sentence`` parameters."""

    def get(self):
        # Copy both request parameters straight into the template context.
        context = {name: self.request.get(name)
                   for name in ("spam_result", "given_sentence")}
        page = jinja_environment.get_template('test_result.html')
        self.response.out.write(page.render(context))
# WSGI application: '/' serves the form and classification (MainHandler),
# '/test_result' renders the result page (test_resultHandler).
app = webapp2.WSGIApplication(
    routes=[
        ('/', MainHandler),
        ('/test_result', test_resultHandler),
    ],
    debug=True,
)