The function common_precision takes two NumPy arrays, say x and y. I want to make sure that both end up with the same dtype, namely the higher-precision one of the two. Relational comparison of dtypes seems to do something close to what I want, but:

  1. I don't know what it actually compares
  2. It thinks that numpy.int64 < numpy.float16, which I'm not sure I agree with

    def common_precision(x, y):
        if x.dtype > y.dtype:
            y = y.astype(x.dtype)
        else:
            x = x.astype(y.dtype)
        return (x, y)

Edited: Thanks to kennytm's answer I found that NumPy's find_common_type does exactly what I wanted.


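    import numpy as np

    def common_precision(x, y):
        # Requires NumPy < 2.0: find_common_type was removed in NumPy 2.0.
        dtype = np.find_common_type([x.dtype, y.dtype], [])
        # Convert only when needed, so matching arrays are returned as-is.
        if x.dtype != dtype: x = x.astype(dtype)
        if y.dtype != dtype: y = y.astype(dtype)
        return x, y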
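
On NumPy versions where find_common_type is not available, np.result_type applies the same promotion rules to a pair of dtypes. The following is a minimal sketch of that variant, not the code from the original edit; the sample arrays are made up for illustration.

```python
import numpy as np

def common_precision(x, y):
    # np.result_type applies NumPy's type promotion rules to the two dtypes.
    dtype = np.result_type(x.dtype, y.dtype)
    # copy=False returns the input array itself when its dtype already matches.
    return x.astype(dtype, copy=False), y.astype(dtype, copy=False)

# Illustrative inputs (assumed, not from the question).
x = np.arange(3, dtype=np.int64)
y = np.ones(3, dtype=np.float16)
a, b = common_precision(x, y)
print(a.dtype, b.dtype)  # float64 float64
```

Note that the common dtype can be wider than both inputs: neither int64 nor float16 can hold every value of the other, so both arrays are promoted to float64.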

1 Answer

x.dtype > y.dtype means that y.dtype can be cast to x.dtype (and x.dtype != y.dtype), so:

>>> numpy.dtype('i8') < numpy.dtype('f2')
False
>>> numpy.dtype('i8') > numpy.dtype('f2')
False
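
The two False results can be cross-checked with np.can_cast, which by default tests whether a cast is "safe". This is a small sketch that assumes the dtype comparison is built on safe casting, as described above; the float64 lines are added only for contrast.

```python
import numpy as np

i8 = np.dtype('i8')  # int64
f2 = np.dtype('f2')  # float16
f8 = np.dtype('f8')  # float64

# Neither direction is a safe cast, so neither dtype is "greater".
print(np.can_cast(i8, f2))  # False
print(np.can_cast(f2, i8))  # False

# Both can be safely cast to float64.
print(np.can_cast(i8, f8))  # True
print(np.can_cast(f2, f8))  # True
```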

float16 and int64 are simply not compatible with each other. You can extract some information about them, for example:

>>> numpy.dtype('f2').kind
'f'
>>> numpy.dtype('f2').itemsize
2
>>> numpy.dtype('i8').kind
'i'
>>> numpy.dtype('i8').itemsize
8

and decide on your comparison scheme based on that.
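
One possible scheme built from kind and itemsize is sketched below. The KIND_ORDER ranking and the precision_key helper are hypothetical names chosen for illustration; they are not part of NumPy, and the ranking itself is an assumption you may want to change.

```python
import numpy as np

# Hypothetical ranking of dtype kinds (an assumption, not a NumPy rule):
# bool < unsigned int < signed int < float < complex.
KIND_ORDER = {'b': 0, 'u': 1, 'i': 2, 'f': 3, 'c': 4}

def precision_key(dtype):
    """Return a sortable (kind rank, size in bytes) tuple for a dtype."""
    dtype = np.dtype(dtype)
    return (KIND_ORDER[dtype.kind], dtype.itemsize)

# Kind is compared first, so any float outranks any integer here.
print(precision_key('i8') < precision_key('f2'))  # True
print(precision_key('f4') < precision_key('f8'))  # True
```

Ranking by kind first puts int64 below float16; swapping the tuple order to (itemsize, kind rank) would rank them the other way around, so the choice depends on what "higher precision" should mean for your data.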

Answered 2017-02-10T17:31:55.207