我正在尝试构建一个以稀疏张量列表作为输入的模型。(列表长度等于批量大小)
我使用稀疏张量的原因是我必须将邻接矩阵传递给我的 GNN 模型,而且它非常稀疏。(~99%)
我熟悉使用 pytorch,将稀疏张量输入网络非常容易。
但是我发现我必须使用 tf.data.Dataset 或 keras.utils.Sequence 在 tensorflow 中制作数据集。
但是当我使用稀疏张量列表作为输入时,这些方法会向我抛出错误。
例如,下面的代码使 TypeError
import tensorflow as tf
tf.data.Dataset.from_tensor_slices(sparse_lists)
TypeError: Neither a SparseTensor nor SparseTensorValue:
[<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2e25b5c0>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c22ada0>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c22a400>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c1ed240>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c1ed390>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c1ed470>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c1ed5c0>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c1ed710>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c1ed828>,
<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fbf2c1ed940>].
我知道如果我将列表中的所有稀疏张量连接为一个巨大的张量,它将起作用。但是,这不是我的选择,因为我以后必须对稀疏张量使用索引。(如果我将 2D 稀疏张量连接成 3D 稀疏张量,我不能使用如下索引)
Some3DSparseTensor[:10]
此外,这将花费更多时间,因为我必须对 3D 张量进行切片,以便与其他密集网络进行矩阵乘法。
此外,我知道如果我通过索引创建稀疏张量,每批的值会很好,但是每批会花费太多时间。
因此,由于索引、时间问题,我想让 tf.data.Dataset 能够从稀疏张量列表中生成批处理。
有谁能够帮我?:)
长话短说,
我所拥有的:稀疏张量列表(例如 1000000 长度列表)
我需要做的:稀疏张量的批处理列表(例如 1024 长度列表,而不是稀疏连接)