模型训练之.record文件生成.ckpt文件
此文基于:win10 Tensorflow Object Detection API配置
修改配置文件
下载预训练模型提取其中的相关文件稍加修改作为我们自己的文件
预训练模型可以从以下两个地方获取:
1. https://github.com/tensorflow/models/tree/master/research/slim
我这里随便选一个下载ssd_mobilenet_v1_coco
注:型号名称末尾的星号(☆)表示此型号支持TPU训训练。
解压后得到的文件夹“ssd_mobilenet_v1_coco_2018_01_28”全部复制到“F:\labelImg”目录下
接下来需要设置配置文件, 进入 Object Detection github 对应页面 寻找 配置文件的Sample
把下载下来的“ssd_mobilenet_v1_coco.config”文件复制到“F:\labelImg\tfrecords\hand”并对以下的地方做出修改:
第一处:
1 2 |
ssd { num_classes: 1 # 修改为自己检测的目标的类别数, 例如我的是一个 |
第二处:
在原有的模型上加入自己的模型进行微调:设置fine_tune_checkpoint路径为已训练的模型model.ckpt。同时from_detection_checkpoint为true。
全新训练自己的模型:注释掉fine_tune_checkpoint且from_detection_checkpoint设置为false
1 2 3 |
#fine_tune_checkpoint: "F:\\labelImg\\ssd_mobilenet_v1_coco_2018_01_28\\model.ckpt" #from_detection_checkpoint: true from_detection_checkpoint: false |
第三处:
训练集及pbtxt路径
1 2 3 4 5 6 |
train_input_reader: { tf_record_input_reader { input_path: "F:\\labelImg\\tfrecords\\hand\\train.record" } label_map_path: "F:\\labelImg\\tfrecords\\hand\\label_map.pbtxt" } |
第4处:
测试集及pbtxt路径
1 2 3 4 5 |
eval_input_reader: { tf_record_input_reader { input_path: "F:\\labelImg\\tfrecords\\hand\\eval.record" } label_map_path: "F:\\labelImg\\tfrecords\\hand\\label_map.pbtxt" |
开始训练:方法有二
方法一:
利用“models-master\research\object_detection”目录下的model_main.py文件进行训练,
运行:
1 |
python object_detection/model_main.py --pipeline_config_path=F:\\labelImg\\tfrecords\\hand\\ssd_mobilenet_v1_coco.config --model_dir=F:\\labelImg\\tfrecords\\hand\\pb --num_train_steps=50000 --num_eval_steps=2000 --alsologtostderr |
1 2 3 4 |
#找到此行代码 results.dataset['categories'] = copy.deepcopy(self.dataset['categories']) #修改为 results.dataset['categories'] = copy.deepcopy(list(self.dataset['categories'])) |
利用“models-master\research\object_detection\legacy”目录下的train.py文件进行训练,
运行:
1 |
python legacy/train.py --logtostderr --train_dir=F:\labelImg\tfrecords\hand\ckpt --pipeline_config_path=F:\labelImg\tfrecords\hand\ssd_mobilenet_v1_coco.config |
成功后在“pb”目录下生成以下文件:
等待loss稳定在一个比较小的值之间,则可以停止训练。(直接关闭窗口以上即可)