python目标检测SSD算法训练部分源码详解

目录
  • 学习前言
  • 讲解构架
  • 模型训练的流程
    • 1、设置参数
    • 2、读取数据集
    • 3、建立ssd网络。
    • 4、预处理数据集
    • 5、框的编码
    • 6、计算loss值
    • 7、训练模型并保存
  • 开始训练

学习前言

……又看了很久的SSD算法,今天讲解一下训练部分的代码。预测部分的代码可以参照https://blog.csdn.net/weixin_44791964/article/details/102496765

讲解构架

本次教程的讲解主要是对训练部分的代码进行讲解,该部分讲解主要是对训练函数的执行过程与执行思路进行详解

训练函数的执行过程大体上分为:

1、设定训练参数。

2、读取数据集。

3、建立ssd网络。

4、预处理数据集。

5、对ground truth实际框进行编码,使其格式符合神经网络的预测结果,便于比较。

6、计算loss值。

7、利用优化器完成梯度下降并保存模型。

在看本次算法前,建议先下载我简化过的源码,配合观看,具体运行方法在开始训练部分
链接:https://pan.baidu.com/s/1MeFsWrv5dAo2Lo6T5ZYsRw
提取码:eo3d

模型训练的流程

本文使用的ssd_vgg_300的源码源于https://github.com/balancap/SSD-Tensorflow,本文对其进行了简化,保留了上一次筛选出的预测部分,还加入了训练部分,便于理顺整个SSD的框架。

1、设置参数

在载入数据库前,首先要设定一系列的参数,这些参数可以分为几个部分。
第一部分是SSD网络中的一些标志参数:

# =========================================================================== #
# SSD Network flags.
# =========================================================================== #
# localization框的衰减比率
tf.app.flags.DEFINE_float(
    'loss_alpha', 1., 'Alpha parameter in the loss function.')
# 正负样本比率
tf.app.flags.DEFINE_float(
    'negative_ratio', 3., 'Negative ratio in the loss function.')
# ground truth处理后,匹配得分高于match_threshold属于正样本
tf.app.flags.DEFINE_float(
    'match_threshold', 0.5, 'Matching threshold in the loss function.')

第二部分是训练时的参数(包括训练效果输出、保存方案等):

# =========================================================================== #
# General Flags.
# =========================================================================== #
# train_dir用于保存训练后的模型和日志
tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/tfmodel/',
    'Directory where checkpoints and event logs are written to.')
# num_readers是在对数据集进行读取时所用的平行读取器个数
tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')
# 在进行训练batch的构建时,所用的线程数
tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')
# 每十步进行一次log输出,在窗口上
tf.app.flags.DEFINE_integer(
    'log_every_n_steps', 10,
    'The frequency with which logs are print.')
# 每600秒存储一次记录
tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 600,
    'The frequency with which summaries are saved, in seconds.')
# 每600秒存储一次模型
tf.app.flags.DEFINE_integer(
    'save_interval_secs', 600,
    'The frequency with which the model is saved, in seconds.')
# 可以使用的gpu内存数量
tf.app.flags.DEFINE_float(
    'gpu_memory_fraction', 0.7, 'GPU memory fraction to use.')

第三部分是优化器参数:

# =========================================================================== #
# Optimization Flags.
# =========================================================================== #
# 优化器参数
# weight_decay参数
tf.app.flags.DEFINE_float(
    'weight_decay', 0.00004, 'The weight decay on the model weights.')
# 使用什么优化器
tf.app.flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
    '"ftrl", "momentum", "sgd" or "rmsprop".')
tf.app.flags.DEFINE_float(
    'adadelta_rho', 0.95,
    'The decay rate for adadelta.')
tf.app.flags.DEFINE_float(
    'adagrad_initial_accumulator_value', 0.1,
    'Starting value for the AdaGrad accumulators.')
tf.app.flags.DEFINE_float(
    'adam_beta1', 0.9,
    'The exponential decay rate for the 1st moment estimates.')
tf.app.flags.DEFINE_float(
    'adam_beta2', 0.999,
    'The exponential decay rate for the 2nd moment estimates.')
tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                          'The learning rate power.')
tf.app.flags.DEFINE_float(
    'ftrl_initial_accumulator_value', 0.1,
    'Starting value for the FTRL accumulators.')
tf.app.flags.DEFINE_float(
    'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
tf.app.flags.DEFINE_float(
    'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
tf.app.flags.DEFINE_float(
    'momentum', 0.9,
    'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

第四部分是学习率参数:

# =========================================================================== #
# Learning Rate Flags.
# =========================================================================== #
# 学习率衰减的方式,有固定、指数衰减等
tf.app.flags.DEFINE_string(
    'learning_rate_decay_type',
    'exponential',
    'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
    ' or "polynomial"')
# 初始学习率
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
# 结束时的学习率
tf.app.flags.DEFINE_float(
    'end_learning_rate', 0.0001,
    'The minimal end learning rate used by a polynomial decay learning rate.')
tf.app.flags.DEFINE_float(
    'label_smoothing', 0.0, 'The amount of label smoothing.')
# 学习率衰减因素
tf.app.flags.DEFINE_float(
    'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
tf.app.flags.DEFINE_float(
    'num_epochs_per_decay', 2.0,
    'Number of epochs after which learning rate decays.')
tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average.'
    'If left as None, then moving averages are not used.')

第五部分是数据集参数:

# =========================================================================== #
# Dataset Flags.
# =========================================================================== #
# 数据集名称
tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')
# 数据集种类个数
tf.app.flags.DEFINE_integer(
    'num_classes', 21, 'Number of classes to use in the dataset.')
# 训练还是测试
tf.app.flags.DEFINE_string(
    'dataset_split_name', 'train', 'The name of the train/test split.')
# 数据集目录
tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')
tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')
tf.app.flags.DEFINE_string(
    'model_name', 'ssd_300_vgg', 'The name of the architecture to train.')
tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')
# 每一次训练batch的大小
tf.app.flags.DEFINE_integer(
    'batch_size', 32, 'The number of samples in each batch.')
# 训练图片的大小
tf.app.flags.DEFINE_integer(
    'train_image_size', None, 'Train image size')
# 最大训练次数
tf.app.flags.DEFINE_integer('max_number_of_steps', 50000,
                            'The maximum number of training steps.')

第六部分是微修已有的模型所需的参数:

# =========================================================================== #
# Fine-Tuning Flags.
# =========================================================================== #
# 该部分参数用于微修已有的模型
# 原模型的位置
tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'The path to a checkpoint from which to fine-tune.')
tf.app.flags.DEFINE_string(
    'checkpoint_model_scope', None,
    'Model scope in the checkpoint. None if the same as the trained model.')
# 哪些变量不要
tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes', None,
    'Comma-separated list of scopes of variables to exclude when restoring '
    'from a checkpoint.')
# 那些变量不训练
tf.app.flags.DEFINE_string(
    'trainable_scopes', None,
    'Comma-separated list of scopes to filter the set of variables to train.'
    'By default, None would train all the variables.')
# 忽略丢失的变量
tf.app.flags.DEFINE_boolean(
    'ignore_missing_vars', False,
    'When restoring a checkpoint would ignore missing variables.')

FLAGS = tf.app.flags.FLAGS

所有的参数的意义我都进行了标注,在实际训练的时候需要修改一些参数的内容,这些参数看起来多,其实只是包含了一个网络训练所有必须的部分:

网络主体参数;

训练时的普通参数(包括训练效果输出、保存方案等);

优化器参数;

学习率参数;

数据集参数;

微修已有的模型的参数设置。

2、读取数据集

在训练流程中,其通过如下函数读取数据集

##########################读取数据集部分#############################
# 选择数据库
dataset = dataset_factory.get_dataset(
    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

dataset_factory里面放的是数据集获取和处理的函数,这里面对应了4个数据集, 利用datasets_map存储了四个数据集的处理代码。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import cifar10
from datasets import imagenet
from datasets import pascalvoc_2007
from datasets import pascalvoc_2012

datasets_map = {
    'cifar10': cifar10,
    'imagenet': imagenet,
    'pascalvoc_2007': pascalvoc_2007,
    'pascalvoc_2012': pascalvoc_2012,
}

def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
    """
    给定一个数据集名和一个拆分名返回一个数据集。
    参数:
        name: String, 数据集名称
        split_name: 训练还是测试
        dataset_dir: 存储数据集文件的目录。
        file_pattern: 用于匹配数据集源文件的文件模式。
        reader: tf.readerbase的子类。如果保留为“none”,则使用每个数据集定义的默认读取器。
    Returns:
        数据集
    """
    if name not in datasets_map:
        raise ValueError('Name of dataset unknown %s' % name)
    return datasets_map[name].get_split(split_name,
                                        dataset_dir,
                                        file_pattern,
                                        reader)

我们这里用到pascalvoc_2012的数据,所以当返回datasets_map[name].get_split这个代码时,实际上调用的是:

pascalvoc_2012.get_split(split_name,
						dataset_dir,
						file_pattern,
						reader)

在pascalvoc_2012中get_split的执行过程如下,其中file_pattern = ‘voc_2012_%s_*.tfrecord’,这个名称是训练的图片的默认名称,实际训练的tfrecord文件名称像这样voc_2012_train_001.tfrecord,意味着可以读取这样的训练文件:

def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Gets a dataset tuple with instructions for reading ImageNet.
    Args:
      split_name: 训练还是测试
      dataset_dir: 数据集的位置
      file_pattern: 匹配数据集源时要使用的文件模式。
                    假定模式包含一个'%s'字符串,以便可以插入拆分名称
      reader: TensorFlow阅读器类型。
    Returns:
      数据集.
    """
    if not file_pattern:
        file_pattern = FILE_PATTERN
    return pascalvoc_common.get_split(split_name, dataset_dir,
                                      file_pattern, reader,
                                      SPLITS_TO_SIZES,
                                      ITEMS_TO_DESCRIPTIONS,
                                      NUM_CLASSES)

再进入到pascalvoc_common文件后,实际上就开始对tfrecord的文件进行分割了,通过代码注释我们了解代码的执行过程,其中tfrecord的文件读取就是首先按照keys_to_features的内容进行文件解码,解码后的结果按照items_to_handlers的格式存入数据集

def get_split(split_name, dataset_dir, file_pattern, reader,
              split_to_sizes, items_to_descriptions, num_classes):
    """Gets a dataset tuple with instructions for reading Pascal VOC dataset.
    给定一个数据集名和一个拆分名返回一个数据集。
    参数:
        name: String, 数据集名称
        split_name: 训练还是测试
        dataset_dir: 存储数据集文件的目录。
        file_pattern: 用于匹配数据集源文件的文件模式。
        reader: tf.readerbase的子类。如果保留为“none”,则使用每个数据集定义的默认读取器。
    Returns:
        数据集
    """
    if split_name not in split_to_sizes:
        raise ValueError('split name %s was not recognized.' % split_name)

    # file_pattern是取得的tfrecord数据集的位置
    file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

    # 当没有的时候使用默认reader
    if reader is None:
        reader = tf.TFRecordReader
    # VOC数据集中的文档内容
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
    }
    # 解码方式
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
        'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
        'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
    }
    # 将tfrecord上keys_to_features的部分解码到items_to_handlers上
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    labels_to_names = None
    if dataset_utils.has_labels(dataset_dir):
        labels_to_names = dataset_utils.read_label_file(dataset_dir)

    return slim.dataset.Dataset(
            data_sources=file_pattern,  # 数据源
            reader=reader,              # tf.TFRecordReader
            decoder=decoder,            # 解码结果
            num_samples=split_to_sizes[split_name], # 17125
            items_to_descriptions=items_to_descriptions,    # 每一个item的描述
            num_classes=num_classes,                        # 种类
            labels_to_names=labels_to_names)

通过上述一系列操作,实际上是返回了一个slim.dataset.Dataset数据集,而一系列函数的调用,实际上是为了调用对应的数据集。

3、建立ssd网络。

建立ssd网络的过程并不复杂,没有许多函数的调用,实际执行过程如果了解ssd网络的预测部分就很好理解,我这里只讲下逻辑:

1、利用ssd_class = ssd_vgg_300.SSDNet获得SSDNet的类

2、替换种类的数量num_classes参数

3、利用ssd_net = ssd_class(ssd_params)建立网络

4、获得先验框

调用的代码如下:

###########################建立ssd网络##############################
# 获得SSD的网络和它的先验框
ssd_class = ssd_vgg_300.SSDNet
# 替换种类的数量num_classes参数
ssd_params = ssd_class.default_params._replace(num_classes=FLAGS.num_classes)
# 成功建立了网络net,替换参数
ssd_net = ssd_class(ssd_params)
# 获得先验框
ssd_shape = ssd_net.params.img_shape
ssd_anchors = ssd_net.anchors(ssd_shape) # 包括六个特征层的先验框

4、预处理数据集

预处理数据集的代码比较长,但是逻辑并不难理解。

1、获得数据集名称。

2、获取数据集处理的函数。

3、利用DatasetDataProviders从数据集中提供数据,进行数据的预加载。

4、获取原始的图片和它对应的label,框ground truth的位置

5、预处理图片标签和框的位置

具体实现的代码如下:

###########################预处理数据集##############################
# preprocessing_name等于ssd_300_vgg
preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name

# 根据名字进行处理获得处理函数
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    preprocessing_name, is_training=True)

# 打印参数
tf_utils.print_configuration(FLAGS.__flags, ssd_params,
                                dataset.data_sources, FLAGS.train_dir)

# DatasetDataProviders从数据集中提供数据. 通过配置,
# 可以同时使用多个readers或者使用单个reader提供数据。此外,被读取的数据
# 可以被打乱顺序
# 预加载
with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size,
        shuffle=True)
# 获取原始的图片和它对应的label,框ground truth的位置
[image, _, glabels, gbboxes] = provider.get(['image', 'shape',
                                                    'object/label',
                                                    'object/bbox'])

# 预处理图片标签和框的位置
image, glabels, gbboxes = \
    image_preprocessing_fn(image, glabels, gbboxes,
                            out_shape=ssd_shape,
                            data_format=DATA_FORMAT)

在这一部分中,可能存在的疑惑的是第二步和第五步,实际上第五步调用的就是第二步中的图像预处理函数,所以我们只要看懂第二步“获取数据集处理的函数“即可。

获得处理函数的代码是:

# 根据名字进行处理获得处理函数
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    preprocessing_name, is_training=True)

preprocessing_factory的文件夹内存放的都是图片处理的代码,在进入到get_preprocessing方法后,实际上会返回一个preprocessing_fn函数

该函数的作用实际上是返回ssd_vgg_preprocessing.preprocess_image处理后的结果。

ssd_vgg_preprocessing.preprocess_image实际上是preprocess_for_train处理后的结果。

preprocessing_factory的get_preprocessing代码如下:

def get_preprocessing(name, is_training=False):
    preprocessing_fn_map = {
        'ssd_300_vgg': ssd_vgg_preprocessing
    }

    if name not in preprocessing_fn_map:
        raise ValueError('Preprocessing name [%s] was not recognized' % name)

    def preprocessing_fn(image, labels, bboxes,
                         out_shape, data_format='NHWC', **kwargs):
        # 这里实际上调用ssd_vgg_preprocessing.preprocess_image
        return preprocessing_fn_map[name].preprocess_image(
            image, labels, bboxes, out_shape, data_format=data_format,
            is_training=is_training, **kwargs)
    return preprocessing_fn

ssd_vgg_preprocessing的preprocess_image代码如下:

def preprocess_image(image,
                     labels,
                     bboxes,
                     out_shape,
                     data_format,
                     is_training=False,
                     **kwargs):
    """Pre-process an given image.

    Args:
      image: A `Tensor` representing an image of arbitrary size.
      output_height: 预处理后图像的高度。
      output_width: 预处理后图像的宽度。
      is_training: 如果我们正在对图像进行预处理以进行训练,则为true;否则为false
      resize_side_min: 图像最小边的下界,用于保持方向的大小调整,
                如果“is_training”为“false”,则此值
                用于重新缩放
      resize_side_max: 图像最小边的上界,用于保持方向的大小调整
                如果“is_training”为“false”,则此值
                用于重新缩放
                the resize side is sampled from
                [resize_size_min, resize_size_max].

    Returns:
      预处理后的图片
    """
    if is_training:
        return preprocess_for_train(image, labels, bboxes,
                                    out_shape=out_shape,
                                    data_format=data_format)
    else:
        return preprocess_for_eval(image, labels, bboxes,
                                   out_shape=out_shape,
                                   data_format=data_format,
                                   **kwargs)

实际上最终是通过preprocess_for_train处理数据集。

preprocess_for_train处理的过程是:

1、改变数据类型。

2、样本框扭曲。

3、将图像大小调整为输出大小。

4、随机水平翻转图像。

5、随机扭曲颜色。有四种方法。

6、图像减去平均值

执行代码如下:

def preprocess_for_train(image, labels, bboxes,
                         out_shape, data_format='NHWC',
                         scope='ssd_preprocessing_train'):
    """Preprocesses the given image for training.

    Note that the actual resizing scale is sampled from
        [`resize_size_min`, `resize_size_max`].

    参数:
        image: 图片,任意size的图片.
        output_height: 处理后的图片高度.
        output_width: 处理后的图片宽度.
        resize_side_min: 图像最小边的下界,用于保方面调整大小
        resize_side_max: 图像最小边的上界,用于保方面调整大小
    Returns:
        处理过的图片
    """
    fast_mode = False
    with tf.name_scope(scope, 'ssd_preprocessing_train', [image, labels, bboxes]):
        if image.get_shape().ndims != 3:
            raise ValueError('Input must be of size [height, width, C>0]')
        # 改变图片的数据类型
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # 样本框扭曲
        dst_image = image
        dst_image, labels, bboxes, _ = \
            distorted_bounding_box_crop(image, labels, bboxes,
                                        min_object_covered=MIN_OBJECT_COVERED,
                                        aspect_ratio_range=CROP_RATIO_RANGE)
        # 将图像大小调整为输出大小。
        dst_image = tf_image.resize_image(dst_image, out_shape,
                                          method=tf.image.ResizeMethod.BILINEAR,
                                          align_corners=False)

        # 随机水平翻转图像.
        dst_image, bboxes = tf_image.random_flip_left_right(dst_image, bboxes)

        # 随机扭曲颜色。有四种方法.
        dst_image = apply_with_random_selector(
                dst_image,
                lambda x, ordering: distort_color(x, ordering, fast_mode),
                num_cases=4)

        # 图像减去平均值
        image = dst_image * 255.
        image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN])
        # 图像的类型
        if data_format == 'NCHW':
            image = tf.transpose(image, perm=(2, 0, 1))
        return image, labels, bboxes

5、框的编码

该部分利用如下代码调用框的编码代码:

gclasses, glocalisations, gscores = ssd_net.bboxes_encode(glabels, gbboxes, ssd_anchors)

实际上bboxes_encode方法中,调用的是ssd_common模块中的tf_ssd_bboxes_encode。

def bboxes_encode(self, labels, bboxes, anchors,
                    scope=None):
    """
    进行编码操作
    """
    return ssd_common.tf_ssd_bboxes_encode(
        labels, bboxes, anchors,
        self.params.num_classes,
        self.params.no_annotation_label,
        ignore_threshold=0.5,
        prior_scaling=self.params.prior_scaling,
        scope=scope)

ssd_common.tf_ssd_bboxes_encode执行的代码是对特征层每一层进行编码操作。

def tf_ssd_bboxes_encode(labels,
                         bboxes,
                         anchors,
                         num_classes,
                         no_annotation_label,
                         ignore_threshold=0.5,
                         prior_scaling=[0.1, 0.1, 0.2, 0.2],
                         dtype=tf.float32,
                         scope='ssd_bboxes_encode'):
    """
      对每一个特征层进行解码
    """
    with tf.name_scope(scope):
        target_labels = []
        target_localizations = []
        target_scores = []
        for i, anchors_layer in enumerate(anchors):
            with tf.name_scope('bboxes_encode_block_%i' % i):
                t_labels, t_loc, t_scores = \
                    tf_ssd_bboxes_encode_layer(labels, bboxes, anchors_layer,
                                               num_classes, no_annotation_label,
                                               ignore_threshold,
                                               prior_scaling, dtype)
                target_labels.append(t_labels)
                target_localizations.append(t_loc)
                target_scores.append(t_scores)
        return target_labels, target_localizations, target_scores

实际上具体解码的操作在函数tf_ssd_bboxes_encode_layer里,tf_ssd_bboxes_encode_layer解码的思路是:

1、创建一系列变量用于存储编码结果。

    yref, xref, href, wref = anchors_layer

    ymin = yref - href / 2.
    xmin = xref - wref / 2.
    ymax = yref + href / 2.
    xmax = xref + wref / 2.
    vol_anchors = (xmax - xmin) * (ymax - ymin)
    # 1、创建一系列变量存储编码结果
    # 每个特征层的shape
    shape = (yref.shape[0], yref.shape[1], href.size)

    # 每个特征层特定点,特定框的label
    feat_labels = tf.zeros(shape, dtype=tf.int64)  # (m, m, k)
    # 每个特征层特定点,特定框的得分
    feat_scores = tf.zeros(shape, dtype=dtype)

    # 每个特征层特定点,特定框的位置
    feat_ymin = tf.zeros(shape, dtype=dtype)
    feat_xmin = tf.zeros(shape, dtype=dtype)
    feat_ymax = tf.ones(shape, dtype=dtype)
    feat_xmax = tf.ones(shape, dtype=dtype)

2、对所有的实际框都寻找其在特征层中对应的点与其对应的框,并将其标签找到。

    # 用于计算IOU
    def jaccard_with_anchors(bbox):

        int_ymin = tf.maximum(ymin, bbox[0])  # (m, m, k)
        int_xmin = tf.maximum(xmin, bbox[1])
        int_ymax = tf.minimum(ymax, bbox[2])
        int_xmax = tf.minimum(xmax, bbox[3])
        h = tf.maximum(int_ymax - int_ymin, 0.)
        w = tf.maximum(int_xmax - int_xmin, 0.)
        # Volumes.
        # 处理搜索框和bbox之间的联系
        inter_vol = h * w  # 交集面积
        union_vol = vol_anchors - inter_vol \
                    + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])  # 并集面积
        jaccard = tf.div(inter_vol, union_vol)  # 交集/并集,即IOU
        return jaccard  # (m, m, k)

    def condition(i,feat_labels, feat_scores,
             feat_ymin, feat_xmin, feat_ymax, feat_xmax):

        r = tf.less(i, tf.shape(labels))
        return r[0]

    # 该部分用于寻找实际中的框对应特征层的哪个框
    def body(i, feat_labels, feat_scores,
             feat_ymin, feat_xmin, feat_ymax, feat_xmax):
        """
          更新功能标签、分数和bbox。
            -JacCard>0.5时赋值;
        """
        # 取出第i个标签和第i个bboxes
        label = labels[i]  # 当前图片上第i个对象的标签
        bbox = bboxes[i]  # 当前图片上第i个对象的真实框bbox

        # 计算该box和所有anchor_box的IOU
        jaccard = jaccard_with_anchors(bbox)  # 当前对象的bbox和当前层的搜索网格IOU

        # 所有高于历史的分的box被筛选
        mask = tf.greater(jaccard, feat_scores)  # 掩码矩阵,IOU大于历史得分的为True
        mask = tf.logical_and(mask, feat_scores > -0.5)

        imask = tf.cast(mask, tf.int64) #[1,0,1,1,0]
        fmask = tf.cast(mask, dtype)    #[1.,0.,1.,0. ... ]

        # Update values using mask.
        # 保证feat_labels存储对应位置得分最大对象标签,feat_scores存储那个得分
        # (m, m, k) × 当前类别 + (1 - (m, m, k)) × (m, m, k)
        # 更新label记录,此时的imask已经保证了True位置当前对像得分高于之前的对象得分,其他位置值不变

        # 将所有被认为是label的框的值赋予feat_labels
        feat_labels = imask * label + (1 - imask) * feat_labels
        # 用于寻找最匹配的框
        feat_scores = tf.where(mask, jaccard, feat_scores)

        # 下面四个矩阵存储对应label的真实框坐标
        # (m, m, k) × 当前框坐标scalar + (1 - (m, m, k)) × (m, m, k)
        feat_ymin = fmask * bbox[0] + (1 - fmask) * feat_ymin
        feat_xmin = fmask * bbox[1] + (1 - fmask) * feat_xmin
        feat_ymax = fmask * bbox[2] + (1 - fmask) * feat_ymax
        feat_xmax = fmask * bbox[3] + (1 - fmask) * feat_xmax

        return [i + 1, feat_labels, feat_scores,
                feat_ymin, feat_xmin, feat_ymax, feat_xmax]

    i = 0
    # 2、对所有的实际框都寻找其在特征层中对应的点与其对应的框,并将其标签找到。
    (i,feat_labels, feat_scores,feat_ymin, feat_xmin,
     feat_ymax, feat_xmax) = tf.while_loop(condition, body,
                                           [i,
                                            feat_labels, feat_scores,
                                            feat_ymin, feat_xmin,
                                            feat_ymax, feat_xmax])

3、转化成ssd中网络的输出格式。

    # Transform to center / size.
    # 3、转化成ssd中网络的输出格式。
    feat_cy = (feat_ymax + feat_ymin) / 2.
    feat_cx = (feat_xmax + feat_xmin) / 2.
    feat_h = feat_ymax - feat_ymin
    feat_w = feat_xmax - feat_xmin

    # Encode features.

    # 利用公式进行计算
    # 以搜索网格中心点为参考,真实框中心的偏移,单位长度为网格hw
    feat_cy = (feat_cy - yref) / href / prior_scaling[0]
    feat_cx = (feat_cx - xref) / wref / prior_scaling[1]
    # log((m, m, k) / (m, m, 1)) * 5
    # 真实框宽高/搜索网格宽高,取对
    feat_h = tf.log(feat_h / href) / prior_scaling[2]
    feat_w = tf.log(feat_w / wref) / prior_scaling[3]
    # Use SSD ordering: x / y / w / h instead of ours.(m, m, k, 4)
    feat_localizations = tf.stack([feat_cx, feat_cy, feat_w, feat_h], axis=-1)

    return feat_labels, feat_localizations, feat_scores

真实情况下的标签和框在编码完成后,格式与经过网络预测出的标签与框相同,此时才可以计算loss进行对比。

6、计算loss值

通过第五步获得的框的编码后的scores和locations指的是数据集标注的结果,是真实情况。
而计算loss值还需要预测情况。

通过如下代码可以获得每个image的预测情况将图片通过网络进行预测

# 设置SSD网络的参数
arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay,
                                data_format=DATA_FORMAT)

# 将图片经过网络获得它们的框的位置和prediction
with slim.arg_scope(arg_scope):
    _, localisations, logits, _ = \
        ssd_net.net(b_image, is_training=True)

调用loss计算函数计算三个loss值,分别对应正样本,负样本,定位。

# 计算loss值
n_positives_loss,n_negative_loss,localization_loss = ssd_net.losses(logits, localisations,
                                                        b_gclasses, b_glocalisations, b_gscores,
                                                        match_threshold=FLAGS.match_threshold,
                                                        negative_ratio=FLAGS.negative_ratio,
                                                        alpha=FLAGS.loss_alpha,
                                                        label_smoothing=FLAGS.label_smoothing)

# 会得到三个loss值,分别对应正样本,负样本,定位
loss_all = n_positives_loss + n_negative_loss + localization_loss

ssd_net.losses中,具体通过如下方式进行损失值的计算。

1、对所有的图片进行铺平,将其种类预测的转化为(?,num_classes),框预测的格式转化为(?,4),实际种类和实际得分的格式转化为(?),该步可以便于后面的比较与处理。最后将batch个图片平铺到同一表上。

2、在gscores中得到满足正样本得分的pmask正样本,不满足正样本得分的为nmask负样本,因为使用的是gscores,我们可以知道正样本负样本分类是针对真实值的。

3、将不满足正样本的位置设成对应prediction中背景的得分,其它设为1。

4、找到n_neg个最不可能为背景的点(实际上它是背景,这样利用二者计算的loss就很大)

5、分别计算正样本、负样本、框的位置的交叉熵。

def ssd_losses(logits, localisations,
               gclasses, glocalisations, gscores,
               match_threshold=0.5,
               negative_ratio=3.,
               alpha=1.,
               label_smoothing=0.,
               device='/cpu:0',
               scope=None):
    with tf.name_scope(scope, 'ssd_losses'):
        lshape = tfe.get_shape(logits[0], 5)
        num_classes = lshape[-1]
        batch_size = lshape[0]

        # 铺平所有vector
        flogits = []
        fgclasses = []
        fgscores = []
        flocalisations = []
        fglocalisations = []
        for i in range(len(logits)): # 按照图片循环
            flogits.append(tf.reshape(logits[i], [-1, num_classes]))
            fgclasses.append(tf.reshape(gclasses[i], [-1]))
            fgscores.append(tf.reshape(gscores[i], [-1]))
            flocalisations.append(tf.reshape(localisations[i], [-1, 4]))
            fglocalisations.append(tf.reshape(glocalisations[i], [-1, 4]))
        # 上一步所得的还存在batch个行里面,对应batch个图片
        # 这一步将batch个图片平铺到同一表上
        logits = tf.concat(flogits, axis=0)
        gclasses = tf.concat(fgclasses, axis=0)
        gscores = tf.concat(fgscores, axis=0)
        localisations = tf.concat(flocalisations, axis=0)
        glocalisations = tf.concat(fglocalisations, axis=0)
        dtype = logits.dtype

        # gscores中满足正样本得分的mask
        pmask = gscores > match_threshold
        fpmask = tf.cast(pmask, dtype)
        no_classes = tf.cast(pmask, tf.int32)

        nmask = tf.logical_and(tf.logical_not(pmask),# IOU达不到阈值的类别搜索框位置记1
                               gscores > -0.5)
        fnmask = tf.cast(nmask, dtype)

        n_positives = tf.reduce_sum(fpmask)
        # 将预测结果转化成比率
        predictions = slim.softmax(logits)
        nvalues = tf.where(nmask,
                           predictions[:, 0],   # 框内无物体标记为背景预测概率
                           1. - fnmask)         # 框内有物体位置标记为1
        nvalues_flat = tf.reshape(nvalues, [-1])

        # max_neg_entries为实际上负样本的个数
        max_neg_entries = tf.cast(tf.reduce_sum(fnmask), tf.int32)
        # n_neg为正样本的个数*3 + batch_size , 之所以+batchsize是因为每个图最少有一个负样本背景
        n_neg = tf.cast(negative_ratio * n_positives, tf.int32) + batch_size

        n_neg = tf.minimum(n_neg, max_neg_entries)

        # 找到n_neg个最不可能为背景的点
        val, idxes = tf.nn.top_k(-nvalues_flat, k=n_neg)
        max_hard_pred = -val[-1]
        # 在nmask找到n_neg个最不可能为背景的点(实际上它是背景,这样二者的差就很大)
        nmask = tf.logical_and(nmask, nvalues < max_hard_pred)
        fnmask = tf.cast(nmask, dtype)
        n_negative = tf.reduce_sum(fnmask)
        # 交叉熵
        with tf.name_scope('cross_entropy_pos'):
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                  labels=gclasses)
            n_positives_loss = tf.div(tf.reduce_sum(loss * fpmask), n_positives + 0.1, name='value')

        with tf.name_scope('cross_entropy_neg'):
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                  labels=no_classes)
            n_negative_loss = tf.div(tf.reduce_sum(loss * fnmask), n_negative + 0.1, name='value')

        # Add localization loss: smooth L1, L2, ...
        with tf.name_scope('localization'):
            # Weights Tensor: positive mask + random negative.
            weights = tf.expand_dims(alpha * fpmask, axis=-1)
            loss = custom_layers.abs_smooth(localisations - glocalisations)
            localization_loss = tf.div(tf.reduce_sum(loss * weights), n_positives + 0.1, name='value')

        return n_positives_loss,n_negative_loss,localization_loss

7、训练模型并保存

################################优化器设置##############################
learning_rate = tf_utils.configure_learning_rate(FLAGS,
                                                        dataset.num_samples,
                                                        global_step)

optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)

train_op = slim.learning.create_train_op(loss_all, optimizer,
                                        summarize_gradients=True)

#################################训练并保存模型###########################
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
config = tf.ConfigProto(log_device_placement=False,
                        gpu_options=gpu_options)
saver = tf.train.Saver(max_to_keep=5,
                        keep_checkpoint_every_n_hours=1.0,
                        write_version=2,
                        pad_step_number=False)

slim.learning.train(
    train_op,			# 优化器
    logdir=FLAGS.train_dir,		# 保存模型的地址
    master='',
    is_chief=True,
    init_fn=tf_utils.get_init_fn(FLAGS),	# 微调已存在模型时,初始化参数
    number_of_steps=FLAGS.max_number_of_steps,		# 最大步数
    log_every_n_steps=FLAGS.log_every_n_steps,		# 多少时间进行一次命令行输出
    save_summaries_secs=FLAGS.save_summaries_secs,	# 进行一次summary
    saver=saver,
    save_interval_secs=FLAGS.save_interval_secs,	# 多长时间保存一次模型
    session_config=config,
    sync_optimizer=None)

开始训练

在根目录下创建一个名为train.sh的文件。利用git上的bash执行命令行。

首先转到文件夹中。

cd D:/Collection/SSD-Retry

再执行train.sh文件。

bash train.sh

train.sh的代码如下:

DATASET_DIR=./tfrecords
TRAIN_DIR=./logs/
CHECKPOINT_PATH=./checkpoints/ssd_300_vgg.ckpt
python train_demo.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --dataset_name=pascalvoc_2012 \
    --dataset_split_name=train \
    --model_name=ssd_300_vgg \
    --checkpoint_path=${CHECKPOINT_PATH} \
    --save_summaries_secs=60 \
    --save_interval_secs=600 \
    --weight_decay=0.0005 \
    --optimizer=adam \
    --learning_rate=0.001 \
    --batch_size=8

训练效果:

以上就是python目标检测SSD算法训练部分源码详解的详细内容,更多关于python目标检测SSD算法训练的资料请关注我们其它相关文章!

(0)

相关推荐

  • python目标检测SSD算法训练部分源码详解

    目录 学习前言 讲解构架 模型训练的流程 1.设置参数 2.读取数据集 3.建立ssd网络. 4.预处理数据集 5.框的编码 6.计算loss值 7.训练模型并保存 开始训练 学习前言 ……又看了很久的SSD算法,今天讲解一下训练部分的代码.预测部分的代码可以参照https://blog.csdn.net/weixin_44791964/article/details/102496765 讲解构架 本次教程的讲解主要是对训练部分的代码进行讲解,该部分讲解主要是对训练函数的执行过程与执行思路进行详

  • python目标检测SSD算法预测部分源码详解

    目录 学习前言 什么是SSD算法 ssd_vgg_300主体的源码 学习前言 ……学习了很多有关目标检测的概念呀,咕噜咕噜,可是要怎么才能进行预测呢,我看了好久的SSD源码,将其中的预测部分提取了出来,训练部分我还没看懂 什么是SSD算法 SSD是一种非常优秀的one-stage方法,one-stage算法就是目标检测和分类是同时完成的,其主要思路是均匀地在图片的不同位置进行密集抽样,抽样时可以采用不同尺度和长宽比,然后利用CNN提取特征后直接进行分类与回归,整个过程只需要一步,所以其优势是速度

  • Android开发数据结构算法ArrayList源码详解

    目录 简介 ArrayList源码讲解 初始化 扩容 增加元素 一个元素 一堆元素 删除元素 一个元素 一堆元素 修改元素 查询元素 总结 ArrayList优点 ArrayList的缺点 简介 ArrayList是List接口的一个实现类,它是一个集合容器,我们通常会通过指定泛型来存储同一类数据,ArrayList默认容器大小为10,自身可以自动扩容,当容量不足时,扩大为原来的1.5倍,和上篇文章的Vector的最大区别应该就是线程安全了,ArrayList不能保证线程安全,但我们也可以通过其

  • Java并发编程之ConcurrentLinkedQueue源码详解

    一.ConcurrentLinkedQueue介绍 并编程中,一般需要用到安全的队列,如果要自己实现安全队列,可以使用2种方式: 方式1:加锁,这种实现方式就是我们常说的阻塞队列. 方式2:使用循环CAS算法实现,这种方式实现队列称之为非阻塞队列. 从点到面, 下面我们来看下非阻塞队列经典实现类:ConcurrentLinkedQueue (JDK1.8版) ConcurrentLinkedQueue 是一个基于链接节点的无界线程安全的队列.当我们添加一个元素的时候,它会添加到队列的尾部,当我们

  • Android实现屏幕锁定源码详解

    最近有朋友问屏幕锁定的问题,自己也在学习,网上找了下也没太详细的例子,看的资料书上也没有有关屏幕锁定程序的介绍,下个小决心,自己照着官方文档学习下,现在做好了,废话不多说,先发下截图,看下效果,需要注意的地方会加注释,有问题的朋友可以直接留言,我们共同学习交流,共同提高进步!直接看效果图: 一:未设置密码时进入系统设置的效果图如下: 二:设置密码方式预览: 三:密码解密效果图 四:九宫格解密时的效果图 下面来简单的看下源码吧,此处讲下,这个小DEMO也是临时学习下的,有讲的不明白的地方请朋友直接

  • Spring AOP底层源码详解

    ProxyFactory的工作原理 ProxyFactory是一个代理对象生产工厂,在生成代理对象之前需要对代理工厂进行配置.ProxyFactory在生成代理对象之前需要决定到底是使用JDK动态代理还是CGLIB技术. // config就是ProxyFactory对象 // optimize为true,或proxyTargetClass为true,或用户没有给ProxyFactory对象添加interface if (config.isOptimize() || config.isProxy

  • Java8中AbstractExecutorService与FutureTask源码详解

    目录 前言 一.AbstractExecutorService 1.定义 2.submit 3.invokeAll 4.invokeAny 二.FutureTask 1.定义 2.构造方法 3.get 4.run/ runAndReset 5. cancel 三.ExecutorCompletionService 1.定义 2.submit 3.take/ poll 总结 前言 本篇博客重点讲解ThreadPoolExecutor的三个基础设施类AbstractExecutorService.F

  • Django Rest Framework实现身份认证源码详解

    目录 一.Django框架 二.身份认证的两种实现方式: 三.身份认证源码解析流程 一.Django框架 Django确实是一个很强大,用起来很爽的一个框架,在Rest Framework中已经将身份认证全都封装好了,用的时候直接导入authentication.py这个模块就好了.这个模块中5个认证类.但是我们在开发中很少用自带的认证类,而是根据项目实际需要去自己实现认证类.下面是内置的认证类 BaseAuthentication(object):所有的认证相关的类都继承自这个类,我们写的认证

  • Android线程间通信Handler源码详解

    目录 前言 01. 用法 02.源码 03.结语 前言 在[Android]线程间通信 - Handler之使用篇主要讲了 Handler 的创建,发送消息,处理消息 三个步骤.那么接下来,我们也按照这三个步骤,从源码中去探析一下它们具体是如何实现的.本篇是关于创建源码的分析. 01. 用法 先回顾一下,在主线程和非主线程是如何创建 Handler 的. //主线程 private val mHandler: Handler = object : Handler(Looper.getMainLo

  • Spring JPA联表查询之OneToOne源码详解

    目录 前言 源码 注解属性 单向联表 user 实体类 car 实体类 查询结果 双向联表 user 实体 car 实体 查询结果 延迟加载(懒加载) user 实体 查询结果: 查询完会发现,控制台又打印了一个 JPQL: 最后结论 前言 前面几篇我们学习的都是单表查询,就是对一张表中的数据进行查询.而实际项目中,基本都会有多张表联合查询的情况,今天我们就来了解下JPA的联表查询是如做的. 源码 @OneToOne 注解实现一对一关系映射.比如用户跟车辆的关系(这里假设一个人只能有一辆车),一

随机推荐