16

我正在使用 MySQLdb 和 Python。我有一些基本的查询,例如:

c=db.cursor()
c.execute("SELECT id, rating from video")
results = c.fetchall()

我需要将“结果”作为 NumPy 数组,并且我希望节省内存消耗。似乎逐行复制数据效率极低(需要双倍的内存)。有没有更好的方法将 MySQLdb 查询结果转换为 NumPy 数组格式?

我希望使用 NumPy 数组格式的原因是因为我希望能够轻松地对数据进行切片和切块,而在这方面,python 似乎对多维数组不太友好。

e.g. b = a[a[:,2]==1] 

谢谢!

4

3 回答 3

24

该方案使用了 Kieth 的fromiter技术,但更直观地处理 SQL 结果的二维表结构。此外,它通过避免 python 数据类型中的所有重塑和扁平化来改进 Doug 的方法。使用结构化数组,我们几乎可以直接从 MySQL 结果读取到 numpy,几乎完全删除了 python 数据类型。我说“几乎”是因为fetchall迭代器仍然产生 python 元组。

虽然有一个警告,但这并不是什么大问题。您必须提前知道列的数据类型和行数。

知道列类型应该很明显,因为您知道查询大概是什么,否则您总是可以使用 curs.description 和 MySQLdb.FIELD_TYPE.* 常量的映射。

知道行数意味着您必须使用客户端游标(这是默认设置)。我对 MySQLdb 和 MySQL 客户端库的内部结构知之甚少,但我的理解是,当使用客户端游标时,整个结果都会被提取到客户端内存中,尽管我怀疑实际上涉及到一些缓冲和缓存。这意味着对结果使用双倍内存,一次用于游标副本,一次用于数组副本,因此如果结果集很大,最好尽快关闭游标以释放内存。

严格来说,您不必提前提供行数,但这样做意味着数组内存是预先分配一次的,而不是随着更多行从迭代器进来而不断调整大小,这意味着提供巨大的性能提升。

有了这个,一些代码

import MySQLdb
import numpy

conn = MySQLdb.connect(host='localhost', user='bob', passwd='mypasswd', db='bigdb')
curs = conn.cursor() #Use a client side cursor so you can access curs.rowcount
numrows = curs.execute("SELECT id, rating FROM video")

#curs.fetchall() is the iterator as per Kieth's answer
#count=numrows means advance allocation
#dtype='i4,i4' means two columns, both 4 byte (32 bit) integers
A = numpy.fromiter(curs.fetchall(), count=numrows, dtype=('i4,i4'))

print A #output entire array
ids = A['f0'] #ids = an array of the first column
              #(strictly speaking it's a field not column)
ratings = A['f1'] #ratings is an array of the second colum

有关如何指定列数据类型和列名的信息,请参阅 dtype 的 numpy 文档和上面有关结构化数组的链接。

于 2013-08-15T17:26:50.460 回答
16

fetchall方法实际上返回一个迭代器,numpy 有fromiter方法从一个迭代器初始化一个数组。因此,根据表中的数据,您可以轻松地将两者结合起来,或者使用适配器生成器。

于 2011-08-15T05:49:01.690 回答
6

NumPy 的fromiter方法在这里似乎最好(如 Keith 的回答,在此之前)。

使用fromiter将通过调用 MySQLdb 游标方法返回的结果集重新转换为 NumPy 数组很简单,但有几个细节可能值得一提。

import numpy as NP
import MySQLdb as SQL

cxn = SQL.connect('localhost', 'some_user', 'their_password', 'db_name')
c = cxn.cursor()
c.execute('SELECT id, ratings from video')

# fetchall() returns a nested tuple (one tuple for each table row)
results = cursor.fetchall()

# 'num_rows' needed to reshape the 1D NumPy array returend by 'fromiter' 
# in other words, to restore original dimensions of the results set
num_rows = int(c.rowcount)

# recast this nested tuple to a python list and flatten it so it's a proper iterable:
x = map(list, list(results))              # change the type
x = sum(x, [])                            # flatten

# D is a 1D NumPy array
D = NP.fromiter(iterable=x, dtype=float, count=-1)  

# 'restore' the original dimensions of the result set:
D = D.reshape(num_rows, -1)

注意fromiter返回一个一维NumPY 数组,

(当然,这是有道理的,因为您可以使用fromiter通过传递count参数来仅返回单个 MySQL 表行的一部分)。

不过,您必须恢复 2D 形状,因此谓词调用游标方法rowcount。以及在最后一行中对reshape的后续调用。

最后,参数count的默认参数是'-1',它只是检索整个iterable

于 2011-08-15T06:51:21.917 回答