8

我正在使用对象检测 API 使用不同的数据集进行训练,我想知道是否有可能在训练期间获得到达网络的样本图像。

我问这个是因为我试图找到一个很好的数据增强选项组合(这里是选项),但是添加它们的结果更糟。在训练中查看到达网络的内容将非常有帮助。

另一个问题是是否可以让 API 帮助平衡类,以防传递的数据集使它们不平衡。

谢谢!

4

1 回答 1

0

对的,这是可能的。简而言之,您需要获取一个 tf.data.Dataset 的实例。然后,您可以对其进行迭代并将网络输入数据作为 NumPy 数组获取。使用 PIL 或 OpenCV 将其保存到图像文件是微不足道的。

假设你使用 TF2 的伪代码是这样的:

ds = ... get dataset object somehow

sample_num = 0
for features, _ in ds:
    images = features[fields.InputDataFields.image]  # is a [batch_size, H, W, C] float32 tensor with preprocessed images
    batch_size = images.shape[0]
    for i in range(batch_size):
        image = np.array(images[i] * 255).astype(np.uint8)  # assuming input data is only scaled to [0..1]
        cv2.imwrite(output_path, image)

    sample_num += 1
    if sample_num >= MAX_SAMPLES:
        break

这里的诀窍是获取 Dataset 实例。谷歌对象检测API非常复杂,但我想你应该从train_input这里调用函数开始:https ://github.com/tensorflow/models/blob/3c8b6f1e17e230b68519fd8d58c4dd9e9570d789/research/object_detection/inputs.py#L763

它需要描述训练、train_input 和模型的管道配置子部分。

您可以在此处找到有关如何使用管道的一些代码片段:Dynamically Editing Pipeline Config for Tensorflow Object Detection

import argparse

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2


def parse_arguments():                                                                                                                                                                                                                                                
    parser = argparse.ArgumentParser(description='')                                                                                                                                                                                                                  
    parser.add_argument('pipeline')                                                                                                                                                                                                                                   
    parser.add_argument('output')                                                                                                                                                                                                                                     
    return parser.parse_args()                                                                                                                                                                                                                                        


def main():                                                                                                                                                                                                                                                           
    args = parse_arguments()                                                                                                                                                                                                                                          
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()                                                                                                                                                                                                          

    with tf.gfile.GFile(args.pipeline, "r") as f:                                                                                                                                                                                                                     
        proto_str = f.read()                                                                                                                                                                                                                                          
        text_format.Merge(proto_str, pipeline_config)   
于 2021-02-24T12:00:03.490 回答