
MATLAB Reinforcement Learning Toolbox

Author: 万木逢春
Published: 2019-10-18
Column: 帮你学MatLab

Recent MATLAB releases ship with the Reinforcement Learning Toolbox, which makes it easy to build a basic 2-D grid-world environment, set the start state, goal, and obstacles, and create various agent models.

Below is a simple Q-learning training implementation.

clear; close all; clc; % the original used the shorthand script ccc

%% Set up the grid-world environment
GW = createGridWorld(6,6);
GW.CurrentState = '[6,1]';
GW.TerminalStates = '[2,5]';
GW.ObstacleStates = ["[2,3]";"[2,4]";"[3,5]";"[4,5]"];

%% Block transitions into obstacle cells
% (the toolbox function name really is spelled "Transtion")
updateStateTranstionForObstacles(GW)

%% Set the reward: -1 per step, +10 for reaching the terminal state
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;

%% Create the environment and fix the initial state
env = rlMDPEnv(GW);
plot(env)
env.ResetFcn = @() 6; % always reset to state index 6, i.e. '[6,1]'

%% Initialize the Q-learning agent and training options
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
tableRep = rlRepresentation(qTable);
tableRep.Options.LearnRate = 1;
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;
qAgent = rlQAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes = 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;

%% Train
rng(0)
trainingStats = train(qAgent,env,trainOpts);

%% Show the result
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(qAgent,env)
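For reference, the tabular update that a Q-learning agent performs on each step can be sketched by hand. This is a generic illustration of the algorithm, not the toolbox's internal source; the names Q, alpha, gamma, s, a, r, and s2 are my own, and the transition values are made up.

```matlab
% Illustrative tabular Q-learning update (assumed names and values)
nS = 36; nA = 4;          % 6x6 grid, 4 move directions
Q = zeros(nS,nA);         % Q-table, one row per state
alpha = 1;                % learning rate, matching LearnRate above
gamma = 0.99;             % discount factor
s = 6; a = 1; r = -1; s2 = 5;   % one hypothetical transition
% Off-policy target: bootstrap from the best next action, max(Q(s2,:))
Q(s,a) = Q(s,a) + alpha*(r + gamma*max(Q(s2,:)) - Q(s,a));
```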

Below is a simple SARSA training implementation.

clear; close all; clc; % the original used the shorthand script ccc

%% Set up the grid-world environment
GW = createGridWorld(6,6);
GW.CurrentState = '[6,1]';
GW.TerminalStates = '[2,5]';
GW.ObstacleStates = ["[2,3]";"[2,4]";"[3,5]";"[4,5]"];

%% Block transitions into obstacle cells
updateStateTranstionForObstacles(GW)

%% Set the reward: -1 per step, +10 for reaching the terminal state
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;

%% Create the environment and fix the initial state
env = rlMDPEnv(GW);
plot(env)
env.ResetFcn = @() 6; % always reset to state index 6, i.e. '[6,1]'

%% Initialize the SARSA agent and training options
rng(0)
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
tableRep = rlRepresentation(qTable);
tableRep.Options.LearnRate = 1;
agentOpts = rlSARSAAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;
sarsaAgent = rlSARSAAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes = 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;

%% Train
trainingStats = train(sarsaAgent,env,trainOpts);

%% Show the result
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(sarsaAgent,env)
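The only algorithmic difference from Q-learning is the update target: SARSA is on-policy, so it bootstraps from the next action its epsilon-greedy policy actually chooses rather than from the greedy maximum. A hand sketch, with the same assumed names as before plus a hypothetical next action a2:

```matlab
% Illustrative tabular SARSA update (assumed names and values)
nS = 36; nA = 4;
Q = zeros(nS,nA);
alpha = 1; gamma = 0.99;
s = 6; a = 1; r = -1; s2 = 5; a2 = 2;   % one hypothetical on-policy transition
% On-policy target: bootstrap from the chosen action Q(s2,a2),
% not max(Q(s2,:)) as in Q-learning
Q(s,a) = Q(s,a) + alpha*(r + gamma*Q(s2,a2) - Q(s,a));
```

Because of this, SARSA tends to learn a more conservative path near the obstacle cells, while Q-learning converges toward the greedy-optimal route.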

Originally published 2019-10-15 on the WeChat account 帮你学MatLab.
