实现text-detection-ctpn一路的坎坎坷坷

 小编在学习文字检测,因为作者提供的caffe实现没有训练代码(不过训练代码可以参考faster-rcnn的训练代码),所以我打算先使用tensorflow实现,主要是复现前辈的代码,主要是对文字检测模型进行训练。   代码的GitHub地址:https://github.com/eragonruan/text-detection-ctpn   主要写一下自己实现的过程,因为原文给的步骤,小编没有完全实现,所以首先打算解读一下原文步骤,然后加上自己的理解,写下自己可以实现的步骤。 文本检测概述   文本检测可以看成特殊的目标检测,但是它有别与通过目标检测,在通用目标检测中,每个目标都有定义好的边界框,检测出的bbox与当前目标的groundtruth重叠率大于0.5就表示该检测结果正确,文本检测中正确检出需要覆盖整个文本长度,且评判的标准不同于通用目标检测,具体的评判方法参见(ICDAR 2017 RobustReading Competition).所以通用的目标检测方法并不适用文本检测。 1,参数设置 parameters there are some parameters you may need to modify according to your requirement, you can find them in ctpn/text.yml USE_GPU_NMS # whether to use nms implemented in cuda or not DETECT_MODE # H represents horizontal mode, O represents oriented mode, default is H checkpoints_path # the model I provided is in checkpoints/, if you train the model by yourself,it will be saved in output/ 1.1 对其进行翻译如下:   根据我们的一些要求,我们可能需要修改一些参数,文件在ctpn/text.yml USE_GPU_NMS 是否使用在cuda中实现的nms DETECT_MODE H表示水平模式,O表示定向模式,默认为H checkpoints_path 作者提供的模型在checkpoints/ 如果我们自己训练模型,它将保存在 output/ 下面 自己训练的模型在这个路径下面: 1 checkpoints_path: output/ctpn_end2end/voc_2007_trainval 下面展示一下小编训练出来的模型: 2:环境设置 setup requirements: python2.7, tensorflow1.3, cython0.24, opencv-python, easydict,(recommend to install Anaconda) if you have a gpu device, build the library by 1 2 3 cd lib / utils chmod + x make.sh ./make.sh 2.1 对其进行翻译如下:   需求的是python2.7 tensorflow1.3 cython0.24,opencv-python,easydict,(建议安装Anaconda)   (因为我有GPU)所以直接进行第三步,进入lib、utils,执行chmod+x给权限(在给权限之前,make.sh是灰色的(不可执行的文件),执行chmod+x make.sh 则变成绿色(可执行的文件)) 3:准备数据 prepare data First, download the pre-trained model of VGG net and put it in data/pretrain/VGG_imagenet.npy. you can download it from google drive or baidu yun. Second, prepare the training data as referred in paper, or you can download the data I prepared from google drive or baidu yun. Or you can prepare your own data according to the following steps. Modify the path and gt_path in prepare_training_data/split_label.py according to your dataset. And run 1 2 cd lib/prepare_training_data python split_label.py it will generate the prepared data in current folder, and then run 1 python ToVoc.py to convert the prepared training data into voc format. It will generate a folder named TEXTVOC. move this folder to data/ and then run 1 2 cd ../../data ln -s TEXTVOC VOCdevkit2007 3.1 对其进行翻译   首先,下载预先训练的VGG网络模型并将其放在data/pretrain/VGG_imagenet.npy.   其次,准备论文提到的训练数据。或者我们可以放置自己的数据   根据我们的数据集修改prepare_training_data/split_label.py中的path和gt_path路径。并执行下面操作。 1 2 cd lib/prepare_training_data python split_label.py   它将在当前文件夹中生成准备好的数据,然后运行下面代码: 1 python ToVoc.py 将准备好的训练数据转换为voc格式。它将生成一个名为TEXTVOC的文件夹。将此文件夹移动到数据/然后运行 1 2 cd ../../data ln -s TEXTVOC VOCdevkit2007 3.2 数据是否只有VOC2007?   作者给的数据是预处理过的数据,   我们下载了数据,VOCdevkit2007 只有1.06G,但是此数据可以训练自己的模式,要是想训练自己的数据,那么需要自己标注数据,找自己的数据。   作者使用的icdar17的multi lingual scene text dataset, 没有用voc,只是用了他的数据格式,下面给出的数据是作者实现的源数据地址。   gt_path的数据地址:http://rrc.cvc.uab.es/?com=contestant   进入2017MLT 查看如下:   然后我们可以发送邮件,注册用户,并激活,进入下载页面:   找到数据集并下载,因为这是国外网址,所以被墙了,小编没有全部下载下来,就走到了这一步,目前没有下一步(如果有人看到这篇博文,希望把下载的数据能分享给我,先在这里道声谢!!!): 3.3 存放数据   作者训练使用的是6000张图片。使用train或者trainval是一样的,因为用的都是这6000张图片。可以检查一下VOCdevkit2007/VOC2007/ImageSets/Main下面的train.txt和trainval.txt是否正确,是否是6000张图片。你在用自己数据训练的时候也要特别注意一点,数据的标注格式是不是和mlt这个数据集一致,因为split_label这个函数是针对mlt的标注格式来写的,所以如果你原始数据标注格式如果和它不同,转换之后可能会是错的,那么得到的用来训练的数据集可能也不对。   这是作者存放数据的路径,我们修改路径,并放数据(因为源数据没有拿到,所以就数据存放也就做到这一步,没有后续!!)。 对原始gt文件进一步处理的分析(也就是对txt标注数据进行进一步处理),生成对应的xml文件部分内容截图如下: 对split_label的部分代码截取如下: + View Code 3.4 参考知乎大神的准备数据如下:   数据标注   在标注数据的时候采用的是顺时针方向,一次是左上角坐标点,右上角坐标点,右下角坐标点,左下角坐标点(即x1,y1,x2,y2,x3,y3,x4,y4),,这里的标注方式与通用目标检测的目标检测方式一样,这里我标注的数据是生成到txt中,具体格式如下:   x1,y1,x2,y2,x3,y3,x4,y4 分别是一个框的四个角点的x,y坐标。这是因为作者用的mlt训练的,他的数据就是这么标注的,如果你要用一些水平文本的数据集,标注是x,y,w,h的,也是可以的,修改一下split_label的代码,或者写个小脚本把x,y,w,h转换成x1,y1,x2,y2,x3,y3,x4,y4就行。   数据处理   根据ctpn训练数据的要求,需要对上述数据(txt标注数据)进行进一步的处理,生成对应的xml文件,具体格式参考pascal voc 具体的训练数据截图和生成的pascal voc格式如下图:   处理数据的时候执行下面代码(和原文一致) 1 2 3 4 5 cd lib/prepare_training_data python split_label.py python ToVoc.py cd ../../data ln -s TEXTVOC VOCdevkit2007   注意:这里生成的数据会在当前目录下,文件夹为TEXTVOC,需要将该文件夹移至/data目录下,然后再做VOCdevikt2007的软连接。 3.5 准备数据注意事项   在原作者使用那6000张图片的话,roidb和image_index都是6000,因为使用的train和trainval是一样的,所以我们在使用自己数据训练的时候也要特别注意一点,数据的标注格式是不是与mlt这个数据集一致,因为split_label这个函数是针对mlt的标注格式来写的,所以我们原始数据标注格式如果和它不同,转化之后可能会是错的,那么得来的用来训练的数据集可能也不对。   cache是为了加速数据读取,所以不会每次重新生成,更换了数据集需要手动清理。 3.6 训练数据的格式是什么样子,是否需要准备图片?   其实想了解自己准备图片的格式,以及图片中的文字区域的坐标是否需要手动标出,才能训练。   上面也说了训练数据的格式是x1,y1,x2,y2,x3,y3,x4,y4 ,当然了自己标注比较麻烦,这里我们可以直接使用一些公开的数据集,原作者使用的额是multi lingual scene texts dataset。 4:训练 Simplely run 1 python ./ctpn/train_net.py you can modify some hyper parameters in ctpn/text.yml, or just used the parameters I set. The model I provided in checkpoints is trained on GTX1070 for 50k iters. If you are using cuda nms, it takes about 0.2s per iter. So it will takes about 2.5 hours to finished 50k iterations. 4.1:对其进行翻译 简单的运行   你可以在ctpn/text.yml中修改一些参数,或者只使用作者设置的参数   作者提供的模型在GTX1070上训练了50K iters   如果我们正在使用cuda nms ,它每次约需要0.2秒,因此完成50k迭代需要大约2.5小时 当然,我们可以指定在那块显卡上运行,比如我这里指定选择第一块显卡上训练,训练的命令如下: 1 CUDA_VISIBLE_DEVICES="0" python ./ctpn/train_net.py 4.2 成功运行截图!!! 4.3:执行训练代码报的一个错误如下 1 AttributeError: module 'tensorflow.python.ops.gen_logging_ops' has no attribute '_image_summary'   tensroflow 新版本相较于一些老版本更改了一些函数和变量类型。可以到 \lib\fast_rcnn\train.py 内尝试把 build_image_summary(self) 函数整体替换为以下语句: 1 2 3 4 5 6 7 8 9 10 def build_image_summary(self): # A simple graph for write image summary log_image_data = tf.placeholder(tf.uint8, [None, None, 3]) log_image_name = tf.placeholder(tf.string) from tensorflow.python.ops import gen_logging_ops from tensorflow.python.framework import ops as _ops log_image = tf.summary.image(str(log_image_name), tf.expand_dims(log_image_data, 0), max_outputs=1) _ops.add_to_collection(_ops.GraphKeys.SUMMARIES, log_image) return log_image, log_image_data, log_image_name   也就是把原文中那句替换成下面这句: 1 2 log_image = tf.summary.image(str(log_image_name), tf.expand_dims(log_image_data, 0), max_outputs=1) 4.4 在训练时候,训练集扩展了2倍,目的是什么?   在训练时候,训练集扩展了2倍,图片倍翻转了,这样做的目的是扩展训练集。 5:部分代码解析 5.1 train_net.py的代码解析 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 import os.path import pprint import sys #os.getcwd()返回当前工作目录 sys.path.append()用于将前面的工作目录添加到搜索路径中 sys.path.append(os.getcwd()) from lib.fast_rcnn.train import get_training_roidb, train_net from lib.fast_rcnn.config import cfg_from_file, get_output_dir, get_log_dir from lib.datasets.factory import get_imdb from lib.networks.factory import get_network from lib.fast_rcnn.config import cfg if __name__ == '__main__': #存放训练参数 cfg_from_file('ctpn/text.yml') print('Using config:') # pprint函数的pprint模块下的方法是一种标准的格式化输出方式。 # pprint(object, stream=None, indent=1, width=80, depth=None, *, compact=False) # 这里是将训练的参数格式化显示出来 pprint.pprint(cfg) # 读取VOC中的数据集 imdb = get_imdb('voc_2007_trainval') print('Loaded dataset `{:s}` for training'.format(imdb.name)) # 获得感兴趣区域的数据集 roidb = get_training_roidb(imdb) # 返回程序运行结果存放的文件夹的路径 output_dir = get_output_dir(imdb, None) # 返回程序运行时中间过程产生的文件。 log_dir = get_log_dir(imdb) print('Output will be saved to `{:s}`'.format(output_dir)) print('Logs will be saved to `{:s}`'.format(log_dir)) device_name = '/gpu:0' print(device_name) # 获取VGG网络结构 定义网络结构 network = get_network('VGGnet_train') train_net(network, imdb, roidb, output_dir=output_dir, log_dir=log_dir, pretrained_model='data/pretrain/VGG_imagenet.npy', max_iters=int(cfg.TRAIN.max_steps),restore=bool(int(cfg.TRAIN.restore))) #采用VGG_Net 输入训练图片的数据集,感兴趣区域的数据集等开始训练。。 参考文献:https://zhuanlan.zhihu.com/p/37363942 http://slade-ruan.me/2017/10/22/text-detection-ctpn/ 不经一番彻骨寒 怎得梅花扑鼻香 好文要顶 关注我 收藏该文 战争热诚 关注 - 19 粉丝 - 156 +加关注 0 0 « 上一篇:深度学习论文翻译解析(四):Faster R-CNN: Down the rabbit hole of modern object detection posted @ 2018-12-05 10:30 战争热诚 阅读(22) 评论(0) 编辑 收藏 刷新评论刷新页面返回顶部 注册用户登录后才能发表评论,请 登录 或 注册,访问网站首页。 【推荐】超50万VC++源码: 大型组态工控、电力仿真CAD与GIS源码库! 【福利】华为云4核8G云主机免费试用 【活动】申请成为华为云云享专家 尊享9大权益 【活动】腾讯云+社区开发者大会12月15日首都北京盛大起航! 腾讯云1129 相关博文: · 坎坎坷坷人生路 · 一路走来一路歌! · 研路——一路追寻,一路期盼 · Caffe训练源码基本流程 · 关于小蜘蛛诞生的坎坎坷坷 最新新闻: · Uber聘请两名新的健康高管 大举推进医疗运输业务 · 商业权威跌落神坛,创始人崇拜毁于2018? · 百度宣布发行2.5亿美元债券 · 消息称软银愿景基金明年将在上海开设首个中国办事处 · 能使病毒"隐身"的基因:超过阈值才会被免疫系统发现 » 更多新闻... 昵称:战争热诚 园龄:1年3个月 粉丝:156 关注:19 +加关注 < 2018年12月 > 日 一 二 三 四 五 六 25 26 27 28 29 30 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 1 2 3 4 5 搜索 我的标签 前端开发基础知识(26) 机器学习常用算法及笔记(22) python 算法与面试笔试题(18) 数据库基础知识及其笔试题(18) python 项目及规范要求(13) Linux基础操作知识点(12) 计算机网络基础知识(8) 深度学习论文翻译解析(5) Git(4) 机器学习进阶之路(1) 更多 随笔档案 2018年12月 (2) 2018年11月 (4) 2018年10月 (8) 2018年9月 (6) 2018年8月 (9) 2018年7月 (1) 2018年6月 (6) 2018年5月 (12) 2018年4月 (11) 2018年3月 (11) 2018年2月 (8) 2018年1月 (11) 2017年12月 (11) 2017年11月 (5) 2017年9月 (9) 2017年8月 (5) 最新评论 1. Re:不得不了解的机器学习面试知识点 @你知道所有的未来这样啊 不好意思,是我激进了,当然可以!!... --战争热诚 2. Re:不得不了解的机器学习面试知识点 @战争热诚你可能是误会我的意思了,我的意思我想转发一下你写的这一篇文章,然后在文章里面留你的原文地址... --你知道所有的未来 3. Re:不得不了解的机器学习面试知识点 @你知道所有的未来请您看清楚,这是转发的吗?,转发的我一定会注明作者,但是不好意思,是我自己网上找的题,自己整理的结果,谢谢!!!... --战争热诚 4. Re:不得不了解的机器学习面试知识点 作者转发一下,留原文地址,可以嘛 --你知道所有的未来 5. Re:深度学习论文翻译解析(四):Faster R-CNN: Down the rabbit hole of modern object detection @少农丈哈哈哈,不客气,加油!!!... --战争热诚 阅读排行榜 1. Git安装教程(windows)(33994) 2. python 生成器和迭代器有这篇就够了(13584) 3. 浅谈使用git进行版本控制(7751) 4. python 一篇搞定所有的异常处理(4458) 5. Python 浅析线程(threading模块)和进程(process)(4044) 6. python 常用算法学习(1)(3564) 7. 如何为开发项目编写规范的README文件(windows),此文详解(3539) 8. 战争热诚的python全栈开发之路(2134) 9. 浅析文本挖掘(jieba模块的应用)(2054) 10. 深入学习使用ocr算法识别图片中文字的方法(1746) 评论排行榜 1. 战争热诚的python全栈开发之路(19) 2. 深入学习卷积神经网络中卷积层和池化层的意义(8) 3. 深入学习图像处理——图像相似度算法(5) 4. 记录自己使用GitHub的点点滴滴(4) 5. 不得不了解的机器学习面试知识点(4) 推荐排行榜 1. 战争热诚的python全栈开发之路(10) 2. 网络基础知识-网络协议(8) 3. python 生成器和迭代器有这篇就够了(8) 4. 记录自己使用GitHub的点点滴滴(7) 5. 如何为开发项目编写规范的README文件(windows),此文详解(6) 6. MySQL 进阶之索引(5) 7. 浅谈使用git进行版本控制(5) 8. 深入学习卷积神经网络中卷积层和池化层的意义(5) 9. python 闯关之路三(面向对象与网络编程)(4) 10. python 面向对象之多态与绑定方法https://www.cnblogs.com/wj-1314/p/9952868.html
50000+
5万行代码练就真实本领
17年
创办于2008年老牌培训机构
1000+
合作企业
98%
就业率

联系我们

电话咨询

0532-85025005

扫码添加微信