I know this question is old, but in case someone wants to do something similar, expanding on ahmedhosny's answer:
The new tensorflow datasets API can create dataset objects from python generators, so together with scikit-learn's KFold one option is to build the dataset from the KFold.split() generator:
import numpy as np
from sklearn.model_selection import LeaveOneOut, KFold
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()
from sklearn.datasets import load_iris

data = load_iris()
X = data['data']
y = data['target']

def make_dataset(X_data, y_data, n_splits):
    def gen():
        # each element of the dataset is one complete fold:
        # (X_train, y_train, X_test, y_test)
        for train_index, test_index in KFold(n_splits).split(X_data):
            X_train, X_test = X_data[train_index], X_data[test_index]
            y_train, y_test = y_data[train_index], y_data[test_index]
            yield X_train, y_train, X_test, y_test
    return tf.data.Dataset.from_generator(gen, (tf.float64, tf.float64, tf.float64, tf.float64))

dataset = make_dataset(X, y, 10)
The dataset can then be iterated over either in graph-based tensorflow or with eager execution. With eager execution:
for X_train, y_train, X_test, y_test in tfe.Iterator(dataset):
    ....
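For the graph-based variant mentioned above, a minimal sketch (assuming eager execution is not enabled, i.e. the tf.enable_eager_execution() call is omitted) could use a one-shot iterator and pull one fold per sess.run:

iterator = dataset.make_one_shot_iterator()
next_fold = iterator.get_next()
with tf.Session() as sess:
    for _ in range(10):  # one element per fold, matching n_splits above
        X_train, y_train, X_test, y_test = sess.run(next_fold)
        # build / evaluate the model on this fold here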