DataTransformer类用于输入图片的预处理,在LayerParameter中包含了相关字段,但只在数据读取层中被使用。

文件

include/caffe/data_transformer.hpp
src/caffe/data_transformer.cpp

依赖

  1. Datum类用于存储图片的序列化表示,用于lmdb/leveldb的数据读取和保存
    // lmdb/leveldb中的图片存储格式,一张图片一个Datum
    message Datum {
    optional int32 channels = 1;                  // 通道数 
    optional int32 height = 2;                    // 高度
    optional int32 width = 3;                     // 宽度
    optional bytes data = 4;                      // 实际的数据,以字节数组格式存储
    optional int32 label = 5;                     // 实际的label,32位整数
    repeated float float_data = 6;                // 也可以保存float格式的数据
    optional bool encoded = 7 [default = false];  // data/float_data是否时编码的数据
    }
    
  2. 数据转换的参数出现在LayerParameter中,使用转换参数类TransformationParameter
message LayerParameter {
  ...
  optional TransformationParameter transform_param = 100;   // 数据读取层的预处理参数类
  ...
}
message TransformationParameter {                   // 数据读取层的预处理参数类
  optional float scale = 1 [default = 1];           // 元素值按比例增减
  optional bool mirror = 2 [default = false];       // 是否作随机镜像变换
  optional uint32 crop_size = 3 [default = 0];      // 随机位置截取crop_size*crop_size大小的子图,默认用全图
  optional string mean_file = 4;                    // 均值所在文件
  repeated float mean_value = 5;                    // 均值,不能和mean_file同时出现。可以为1个,或channels个
  optional bool force_color = 6 [default = false];  // 是否强制生成3通道数据
  optional bool force_gray = 7 [default = false];   // 是否强制生成1通道数据
}

成员

// 输入图片的预处理类
template <typename Dtype>
class DataTransformer {
 public:
  explicit DataTransformer(const TransformationParameter& param, Phase phase);  // Train和Test的预处理稍有不同
  virtual ~DataTransformer() {}

  // 不同类型源数据的预处理函数,最后都写到blob中。目标blob可以共享top blob的data数据,这样可以减少数据拷贝
  void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
  void Transform(const vector<Datum> & datum_vector, Blob<Dtype>* transformed_blob);
#ifdef USE_OPENCV
  void Transform(const vector<cv::Mat> & mat_vector, Blob<Dtype>* transformed_blob);
  void Transform(const cv::Mat& cv_img, Blob<Dtype>* transformed_blob);
#endif  // USE_OPENCV
  void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);

  // 推断预处理后的blob的维度,一张图片时,num维度为1
  vector<int> InferBlobShape(const Datum& datum);
  vector<int> InferBlobShape(const vector<Datum> & datum_vector);
#ifdef USE_OPENCV
  vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);
  vector<int> InferBlobShape(const cv::Mat& cv_img);
#endif  // USE_OPENCV

  // 初始化随机数产生器
  void InitRand();
 protected:
  // 随机数产生器:返回[0, n-1]的随机数
  virtual int Rand(int n);
  void Transform(const Datum& datum, Dtype* transformed_data);

  TransformationParameter param_;               // 预处理参数
  shared_ptr<Caffe::RNG> rng_;                  // 内部随机数产生器
  Phase phase_;                                 // 网络状态
  Blob<Dtype> data_mean_;                       // 从mean_file中读取的均值blob
  vector<Dtype> mean_values_;                   // 均值数组,可以只有一个或通道数个
};

results matching ""

    No results matching ""