-
Notifications
You must be signed in to change notification settings - Fork 4
Reinforcement Training
limiao edited this page Dec 14, 2017
·
5 revisions
本页主要介绍一下Reinforcement Training的使用方法。总体基于文章"Self-critical Sequence Training for Image Captioning" 的方法,使用cider作为reward,进行policy gradient训练。 (https://arxiv.org/pdf/1612.00563.pdf)
目前代码在rl_training分支。
cider的计算基于鹤达写的in graph的计算方法,具体代码在im2txt/tf_cider.py里,计算过程需要ngram的document frequency,需要预先计算,运行 scripts/build_document_frequency.sh 即可。
目前支持的模型只有show_and_tell_in_graph_model 和 show_and_tell_advanced_model, 如果要支持其他模型,需要在原有代码上稍微修改一下,可以参考我改动的部分,主要就是用lstm_cell 进行greedy和sample两次decode,将两次decode的结果和长度输出。
使用方法与之前的finetune类似,需要以之前cross entropy训练好的模型基础上训练,可以参考这个脚本advanced-ss_att-finetune-with-decay-da-rl.sh,主要要加入rl_train(表示开启RL训练模式)和document_frequency_file(载入用于计算cider的document frequency)两个参数。这个脚本的设定里,没有finetune CNN, 学习率是随便设的,大概是0.01左右,可以调整。
- 在训练过程中,我感觉loss这个参数对于学习结果的好坏参考意义不大,我在summary里加入了sample_reward和greedy_reward(一个batch里的平均cider得分),可能参考性更好,如果一直在提升,那么是比较好的。