1

我正在尝试用示例数据拟合一个示例概率图模型(使用 pgmpy 的 BayesianModel)。在调用 fit 用数据拟合模型时,出现了类型错误(TypeError)。示例代码如下:

import numpy as np
import pandas as pd
from pgmpy.models import BayesianModel

# Draw a 1000 x 4 array of floats sampled uniformly from the half-open
# interval [0, 2); the astype('float') call keeps the values as floats.
# NOTE(review): the estimator frames in the traceback tabulate state counts
# per node, so continuous floats would presumably yield close to one state
# per row -- discrete data (e.g. np.random.randint) may be what was
# intended; confirm.
data = np.random.uniform(low=0, high=2, size=(1000, 4)).astype('float')
# Bare expression with no effect in a script (presumably a notebook echo).
data

# Wrap the array in a DataFrame with one named column per model variable.
data = pd.DataFrame(data, columns=['cost', 'quality',
'location',
'no_of_people'])

# The first 750 rows are used for fitting; the remaining 250 rows form the
# test set, with the 'no_of_people' column dropped.
train = data[:750]
test = data[750:].drop('no_of_people', axis=1)

# Build the network from a list of edges between the four column names
# (presumably directed as (parent, child) pairs -- verify against pgmpy docs).
restaurant_model = BayesianModel(
[('location', 'cost'),
('quality', 'cost'),
('location', 'no_of_people'),
('cost', 'no_of_people')])

# Fit the model on the training rows; this is the call that raises the
# TypeError shown in the traceback.
restaurant_model.fit(train)

运行后出现以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-173-8e3a85cb8b56> in <module>()
----> 1 restaurant_model.fit(train)

C:\Users\pranav.waila\AppData\Local\Continuum\Anaconda3\lib\site-packages\pgmpy\models\BayesianModel.py in fit(self, data, estimator_type)
    568         estimator = estimator_type(self, data)
    569 
--> 570         cpds_list = estimator.get_parameters()
    571         self.add_cpds(*cpds_list)
    572 

C:\Users\pranav.waila\AppData\Local\Continuum\Anaconda3\lib\site-packages\pgmpy\estimators\MLE.py in get_parameters(self)
     64                 state_counts = state_counts.reindex(sorted(state_counts.index))
     65                 cpd = TabularCPD(node, self.node_card[node],
---> 66                                  state_counts.values[:, np.newaxis])
     67                 cpd.normalize()
     68                 parameters.append(cpd)

C:\Users\pranav.waila\AppData\Local\Continuum\Anaconda3\lib\site-packages\pgmpy\factors\CPD.py in __init__(self, variable, variable_card, values, evidence, evidence_card)
    137             raise TypeError("Values must be a 2D list/array")
    138 
--> 139         super(TabularCPD, self).__init__(variables, cardinality, values.flatten('C'))
    140 
    141     def __repr__(self):

C:\Users\pranav.waila\AppData\Local\Continuum\Anaconda3\lib\site-packages\pgmpy\factors\Factor.py in __init__(self, variables, cardinality, values)
     98 
     99         if values.dtype != int and values.dtype != float:
--> 100             raise TypeError("Values: Expected type int or type float, got ", values.dtype)
    101 
    102         if len(cardinality) != len(variables):

TypeError: ('Values: Expected type int or type float, got ', dtype('int64'))
4

0 回答 0