第05课:准备训练数据

第05课:准备训练数据

终于要开始训练识别熊猫的模型了, 第一步是准备好训练数据,这里有三件事情要做:

  • 收集一定数量的熊猫图片。
  • 将图片中的熊猫用矩形框标注出来。
  • 将原始图片和标注文件转换为TFRecord格式的文件。

Data Labeling

收集熊猫的图片和标注熊猫位置的工作称之为“Data Labeling”,这可能是整个机器学习领域内最低级、最机械枯燥的工作了,有时候大量的 Data Labeling 工作会外包给专门的 Data Labeling 公司做, 以加快速度和降低成本。

当然我们不会把这个工作外包给别人,要从最底层的工作开始!

收集熊猫图片倒不是太难,从谷歌和百度图片上收集 200 张熊猫的图片,应该足够训练一个可用的识别模型了。

然后需要一些工具来做标注,我使用的是 Mac 版的 RectLabel,常用的还有 LabelImgLabelMe 等。

LabelImg 可以生成 PASCAL VOC 格式的标注文件,而 LabelME 是一个基于 Web 的标注工具,可以实现多人同时进行标注。

RectLabel 标注时的界面大概是这样的:

enter image description here

当我们标注完成的时候,它会在 annotations 目录下生产和图片文件名相同的后缀名为 .json 的标注文件。

打开一个标注文件,其内容大概是这样的:

    {
      "filename" : "61.jpg",
      "folder" : "panda_images",
      "image_w_h" : [
        453,
        340
      ],
      "objects" : [
        {
          "label" : "panda",
          "x_y_w_h" : [
            90,
            104,
            364,
            233
          ]
        }
      ]
    }
  • image_w_h:图片的宽和高。
  • objects:图片的中的物体信息、数组。
  • label:在标注的时候指定的物体名称。
  • x_y_w_h:物体位置的矩形框:(xmin、ymin、width、height)。

接下来要做的是耐心的在这 200 张图片上面标出熊猫的位置,这个稍微要花点时间,可以在 这里 找已经标注好的图片数据。

生成 TFRecord

接下来需要一点 Python 代码来将图片和标注文件生成为 TFRecord 文件,TFRecord 文件是由很多tf.train.Example对象序列化以后组成的,先写由一个单独的图片文件生成tf.train.Example对象的函数:

    def create_sample(image_filename, data_dir):
        image_path = os.path.join(data_dir, image_filename)
        annotation_path = os.path.join(data_dir, 'annotations', os.path.splitext(image_filename)[0] + ".json")
        with tf.gfile.GFile(image_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        with open(annotation_path) as fid:
            image_annotation = json.load(fid)
        width = image_annotation['image_w_h'][0]
        height = image_annotation['image_w_h'][1]
        xmins = []
        ymins = []
        xmaxs = []
        ymaxs = []
        classes = []
        classes_text = []

        for obj in image_annotation['objects']:
            classes.append(1)
            classes_text.append('panda')
            box = obj['x_y_w_h']
            xmins.append(float(box[0]) / width)
            ymins.append(float(box[1]) / height)
            xmaxs.append(float(box[0] + box[2] - 1) / width)
            ymaxs.append(float(box[1] + box[3] - 1) / height)

        filename = image_annotation['filename'].encode('utf8')
        tf_example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(filename),
            'image/source_id': dataset_util.bytes_feature(filename),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
            'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label': dataset_util.int64_list_feature(classes),
        }))
        return tf_example

在这里简单说明一下:

  • 通过图片文件名找到对应的标注文件,并读入标注信息。
  • 因为图片中标注的物体都是熊猫,用数字 1 来代表,所以 class 数组里的元素值都为 1,class_text数组的里的元素值都为‘panda’。
  • Object Detection API 里面接受的矩形框输入格式为 (xmin, ymin, xmax, ymax) 和标注文件的 (xmin, ymin, width, height) 不一样,所以要做一下转换。同时需要将这些值归一化:将数值投影到 (0, 1] 的区间内。
  • 将特征组成{特征名:特征值}的 dict 作为参数来创建tf.train.Example

接下来将tf.train.Example对象序列化,我们写一个可以由图片文件列表生成对应 TFRecord 文件的的函数:

    def create_tf_record(example_file_list, data_dir, output_file_path):
        writer = tf.python_io.TFRecordWriter(output_file_path)
        for filename in example_file_list:
            tf_example = create_sample(filename, data_dir)
            writer.write(tf_example.SerializeToString())
        writer.close()

依次调用create_sample函数然后将生成的tf.train.Example对象依次序列化即可。

最后需要将数据集切分为训练集合测试集,将图片文件打乱,然后按照 7:3 的比例进行切分:

    random.seed(42)
    random.shuffle(all_examples)
    num_examples = len(all_examples)
    num_train = int(0.7 * num_examples)
    train_examples = all_examples[:num_train]
    val_examples = all_examples[num_train:]
    create_tf_record(train_examples, data_dir, os.path.join(output_dir, 'train.record'))
    create_tf_record(val_examples, data_dir, os.path.join(output_dir, 'val.record'))

写完这个脚本以后,最好再写一个测试用例来验证这个脚本,因为我们将会花很长的时间来训练,到时候再发现脚本有 bug 就太浪费时间了,我们主要测试create_sample方法有没有根据输入数据生成正确的tf.train.Example对象:

    def test_dict_to_tf_example(self):
        image_file = '61.jpg'
        data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data')
        example = create_sample(image_file, data_dir)

        self._assertProtoEqual(
            example.features.feature['image/height'].int64_list.value, [340])
        self._assertProtoEqual(
            example.features.feature['image/width'].int64_list.value, [453])
        self._assertProtoEqual(
            example.features.feature['image/filename'].bytes_list.value,
            [image_file])
        self._assertProtoEqual(
            example.features.feature['image/source_id'].bytes_list.value,
            [image_file])
        self._assertProtoEqual(
            example.features.feature['image/format'].bytes_list.value, ['jpeg'])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/xmin'].float_list.value,
            [90.0 / 453])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/ymin'].float_list.value,
            [104.0/340])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/xmax'].float_list.value,
            [1.0])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/ymax'].float_list.value,
            [336.0/340])
        self._assertProtoEqual(
            example.features.feature['image/object/class/text'].bytes_list.value,
            ['panda'])
        self._assertProtoEqual(
            example.features.feature['image/object/class/label'].int64_list.value,
            [1])

可以在这里找到全部代码。

完成之后运行脚本,传入图片和标注的文件夹路径和输出文件路径:

python create_tf_record.py --image_dir=PATH_OF_IMAGE_SET --output_dir=OUTPUT_DIR

执行完成后会在由output_dir参数指定的目录生成train.recordval.record文件, 分别为训练集和测试集。

生成 label map 文件

最后还需要一个 label map 文件,很简单,因为我们只有一种物体:熊猫

label_map.pbtxt:


    item {
      id: 1
      name: 'panda'
    }

训练一个熊猫识别模型所需要的训练数据就准备完了,接下来开始在 GPU 主机上面开始训练。

上一篇
下一篇
目录