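A Detailed Guide to TensorFlow Training Configuration Files
January 4, 2019
Reference (official documentation): Configuring the Object Detection Training Pipeline
The TensorFlow Object Detection API uses protobuf files to configure training and evaluation.
The file format is defined in object_detection/protos/pipeline.proto, which imports the proto files for each part of the pipeline: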
```protobuf
import "object_detection/protos/eval.proto";
import "object_detection/protos/graph_rewriter.proto";
import "object_detection/protos/input_reader.proto";
import "object_detection/protos/model.proto";
import "object_detection/protos/train.proto";
```
- model_config: defines which type of model to train and the parameters it uses.
model.proto
```protobuf
syntax = "proto2";

package object_detection.protos;

import "object_detection/protos/faster_rcnn.proto";
import "object_detection/protos/ssd.proto";

// Top level configuration for DetectionModels.
message DetectionModel {
  oneof model {
    FasterRcnn faster_rcnn = 1;
    Ssd ssd = 2;
  }
}
```
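Because of the `oneof model` block, a DetectionModel holds at most one architecture. Setting `faster_rcnn` clears `ssd`, and the reverse is also true. A training pipeline needs one concrete architecture. The stdlib-only sketch below illustrates that rule. It assumes the `model { ... }` section has already been read into a plain Python dict. The helper `selected_architecture` is hypothetical and is not part of the Object Detection API, which relies on the generated protobuf classes instead.

```python
from typing import Dict

# Members of `oneof model` in the DetectionModel message.
ONEOF_MODEL_FIELDS = ("faster_rcnn", "ssd")


def selected_architecture(model_section: Dict[str, dict]) -> str:
    """Return the single architecture set in a dict-shaped `model` section.

    Hypothetical helper: protobuf itself allows "at most one" member of a
    oneof; requiring exactly one reflects that training needs a concrete model.
    """
    present = [name for name in ONEOF_MODEL_FIELDS if name in model_section]
    if len(present) != 1:
        raise ValueError(
            f"exactly one of {ONEOF_MODEL_FIELDS} must be set, got {present}"
        )
    return present[0]


if __name__ == "__main__":
    print(selected_architecture({"ssd": {"num_classes": 90}}))  # ssd
```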
- train_config: determines how the model's parameters are trained (for example, batch size, optimizer and data augmentation).
train.proto
```protobuf
syntax = "proto2";

package object_detection.protos;

import "object_detection/protos/optimizer.proto";
import "object_detection/protos/preprocessor.proto";

// Message for configuring DetectionModel training jobs (train.py).
message TrainConfig {
  // Input queue batch size.
  optional uint32 batch_size = 1 [default=32];

  // Data augmentation options.
  repeated PreprocessingStep data_augmentation_options = 2;

  // ... (remaining fields omitted)
}
```
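In proto2, `optional uint32 batch_size = 1 [default=32]` means that reading `batch_size` returns 32 when `train_config` leaves it out. `HasField` still reports whether the field was written explicitly. A repeated field such as `data_augmentation_options` is simply empty when no entries are given. The sketch below mimics those semantics with a stdlib dataclass. `TrainConfigSketch` is a hypothetical name. Augmentation steps are shown as plain strings for simplicity; in the real message each entry is a `PreprocessingStep` message, written in the config as, for example, `data_augmentation_options { random_horizontal_flip { } }`.

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Declared default taken from `[default=32]` on TrainConfig.batch_size.
BATCH_SIZE_DEFAULT = 32


@dataclass
class TrainConfigSketch:
    """Hypothetical stand-in for TrainConfig that mimics proto2 field semantics."""

    # None models "batch_size was not written in the config file".
    explicit_batch_size: Optional[int] = None
    # A repeated field is an empty list when no entries are given.
    data_augmentation_options: List[str] = field(default_factory=list)

    @property
    def batch_size(self) -> int:
        # Reading an unset proto2 optional field yields its declared default.
        if self.explicit_batch_size is None:
            return BATCH_SIZE_DEFAULT
        return self.explicit_batch_size

    def has_batch_size(self) -> bool:
        # Analogous to proto2 HasField("batch_size"): was it set explicitly?
        return self.explicit_batch_size is not None


if __name__ == "__main__":
    default_cfg = TrainConfigSketch()
    print(default_cfg.batch_size, default_cfg.has_batch_size())  # 32 False

    custom_cfg = TrainConfigSketch(
        explicit_batch_size=8,
        data_augmentation_options=["random_horizontal_flip"],
    )
    print(custom_cfg.batch_size, custom_cfg.has_batch_size())  # 8 True
```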
- eval_config: determines which metrics are reported during evaluation.
eval.proto
```protobuf
syntax = "proto2";

package object_detection.protos;

// Message for configuring DetectionModel evaluation jobs (eval.py).
message EvalConfig {
  // Number of visualization images to generate.
  optional uint32 num_visualizations = 1 [default=10];

  // Number of examples to process of evaluation.
  optional uint32 num_examples = 2 [default=5000];

  // ... (remaining fields omitted)
}
```
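The two EvalConfig fields above set an upper bound on how much evaluation work is done. With the defaults, at most 5000 examples are processed and 10 visualization images are generated. The toy model below assumes that examples are consumed in order and that the first ones are the ones visualized. The real evaluator may choose its visualization images differently, and `plan_evaluation` is a hypothetical helper.

```python
from itertools import islice
from typing import Iterable, List, Tuple, TypeVar

T = TypeVar("T")

# Declared defaults from the EvalConfig excerpt.
NUM_VISUALIZATIONS_DEFAULT = 10
NUM_EXAMPLES_DEFAULT = 5000


def plan_evaluation(
    examples: Iterable[T],
    num_examples: int = NUM_EXAMPLES_DEFAULT,
    num_visualizations: int = NUM_VISUALIZATIONS_DEFAULT,
) -> Tuple[List[T], List[T]]:
    """Toy model of the two fields: which examples are evaluated and visualized.

    Assumption: examples are consumed in order and the first ones are visualized.
    """
    # Stop after num_examples, even if the dataset is larger.
    evaluated = list(islice(examples, num_examples))
    # Visualize at most num_visualizations of the evaluated examples.
    visualized = evaluated[:num_visualizations]
    return evaluated, visualized


if __name__ == "__main__":
    evaluated, visualized = plan_evaluation(range(20000))
    print(len(evaluated), len(visualized))  # 5000 10
```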
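- train_input_config (the train_input_reader section of the config file): defines which dataset is used for training.
- eval_input_config (the eval_input_reader section of the config file): defines which dataset is used for evaluation. This should usually be different from the training dataset.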
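Evaluation data should normally differ from training data. A cheap sanity check before launching a job is therefore to compare the input files of the two readers. The sketch assumes that each reader's paths have already been copied out of its `tf_record_input_reader { input_path: ... }` block into a Python list. `overlapping_inputs` is a hypothetical helper, and it only catches exact duplicate paths because glob patterns and sharded file names are not expanded.

```python
import os
from typing import Iterable, Set


def overlapping_inputs(train_paths: Iterable[str], eval_paths: Iterable[str]) -> Set[str]:
    """Return input files listed by both readers (hypothetical helper).

    Paths are compared after os.path.normpath; glob patterns and sharded
    file names are NOT expanded, so this only catches exact duplicates.
    """
    train = {os.path.normpath(p) for p in train_paths}
    evaluation = {os.path.normpath(p) for p in eval_paths}
    return train & evaluation


if __name__ == "__main__":
    # Hypothetical input_path values from the two tf_record_input_reader blocks.
    print(overlapping_inputs(["data/train.record"], ["data/eval.record"]))  # set()
    leaked = overlapping_inputs(["data/./train.record"], ["data/train.record"])
    if leaked:
        print("eval_input_reader reuses training files:", sorted(leaked))
```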
The configuration file is therefore divided into five parts:
```
model {
(… Add model config here…)
}
train_config : {
(… Add train_config here…)
}
train_input_reader: {
(… Add train_input configuration here…)
}
eval_config: {
}
eval_input_reader: {
(… Add eval_input configuration here…)
}
```
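The real API parses a pipeline file with protobuf's `text_format` module into a `TrainEvalPipelineConfig` message. Before that step, it can be useful to check that all five top-level sections are present. The stdlib-only sketch below approximates that check. It finds top-level `name {` or `name: {` blocks by tracking brace depth. The function names are hypothetical. As a simplification, braces inside string values or `#` comments are not handled. The comments record which message type each section holds.

```python
import re
from typing import List

# Top-level sections of a pipeline config and the message each one holds.
REQUIRED_SECTIONS = (
    "model",               # DetectionModel (model.proto)
    "train_config",        # TrainConfig (train.proto)
    "train_input_reader",  # InputReader (input_reader.proto)
    "eval_config",         # EvalConfig (eval.proto)
    "eval_input_reader",   # InputReader (input_reader.proto)
)

# Matches either "name {" / "name: {" (a block opening) or a lone "}".
_TOKEN = re.compile(r"([A-Za-z_]\w*)\s*:?\s*\{|\}")


def top_level_sections(config_text: str) -> List[str]:
    """Return the names of top-level blocks, in file order.

    Simplification: braces inside string values or comments are not handled.
    """
    names: List[str] = []
    depth = 0
    for match in _TOKEN.finditer(config_text):
        if match.group(0) == "}":
            depth -= 1
        else:
            if depth == 0:
                names.append(match.group(1))
            depth += 1
    return names


def missing_sections(config_text: str) -> List[str]:
    """List required sections that do not appear at the top level."""
    present = set(top_level_sections(config_text))
    return [name for name in REQUIRED_SECTIONS if name not in present]


if __name__ == "__main__":
    print(missing_sections("model {\n  ssd { num_classes: 1 }\n}"))
    # ['train_config', 'train_input_reader', 'eval_config', 'eval_input_reader']
```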