
Training a Reinforcement Learning Model from Images

万木逢春 · 2020-04-16

Obtaining the observation and reward from images

When running Atari environments in gym, the same game can be used in either RAM mode or image mode: RAM mode returns the game's internal state directly, while image mode returns the current game frame.

In the previous posts the observation and reward were read directly in MATLAB or Simulink; here we let the environment return an image instead and use a neural network to extract that information from the picture.

%% Load the environment
ccc % the author's helper script (presumably clear / clc / close all)
env = rlPredefinedEnv('SimplePendulumWithImage-Continuous');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
rng(0) % fix the random seed for reproducibility
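Before building the networks, it is worth confirming what the environment actually returns. A minimal check follows; the names and sizes in the comments are what this predefined environment is expected to report, so treat them as assumptions rather than guarantees:

obsInfo(1).Name       % image channel, expected "pendImage"
obsInfo(1).Dimension  % e.g. [50 50 1], a grayscale rendering of the pendulum
obsInfo(2).Name       % scalar channel, expected "angularRate"
actInfo.LowerLimit    % torque limits, reused later by the actor's scaling layer
actInfo.UpperLimit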

%%
hiddenLayerSize1 = 400;
hiddenLayerSize2 = 300;

%% Initialize the agent
% Critic network, image path: a small CNN reads the pendulum image, then the
% angular-rate and action paths are merged in further down.
imgPath = [
    imageInputLayer(obsInfo(1).Dimension,'Normalization','none','Name',obsInfo(1).Name)
    convolution2dLayer(10,2,'Name','conv1','Stride',5,'Padding',0)
    reluLayer('Name','relu1')
    fullyConnectedLayer(2,'Name','fc1')
    concatenationLayer(3,2,'Name','cat1')
    fullyConnectedLayer(hiddenLayerSize1,'Name','fc2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(hiddenLayerSize2,'Name','fc3')
    additionLayer(2,'Name','add')
    reluLayer('Name','relu3')
    fullyConnectedLayer(1,'Name','fc4')
    ];

% Critic network, angular-rate path
dthetaPath = [
    imageInputLayer(obsInfo(2).Dimension,'Normalization','none','Name',obsInfo(2).Name)
    fullyConnectedLayer(1,'Name','fc5','BiasLearnRateFactor',0,'Bias',0)
    ];
% Critic network, action path
actPath = [
    imageInputLayer(actInfo(1).Dimension,'Normalization','none','Name','action')
    fullyConnectedLayer(hiddenLayerSize2,'Name','fc6','BiasLearnRateFactor',0,'Bias',zeros(hiddenLayerSize2,1))
    ];
% Assemble the three paths into one graph: the angular-rate path feeds the
% concatenation layer, the action path feeds the addition layer.
criticNetwork = layerGraph(imgPath);
criticNetwork = addLayers(criticNetwork,dthetaPath);
criticNetwork = addLayers(criticNetwork,actPath);
criticNetwork = connectLayers(criticNetwork,'fc5','cat1/in2');
criticNetwork = connectLayers(criticNetwork,'fc6','add/in2');

%%
figure
plot(criticNetwork)
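If the layer-graph plot is hard to read, the Deep Learning Toolbox network analyzer shows the same structure as a table of layers, sizes, and connections (not part of the original script):

analyzeNetwork(criticNetwork) % opens the Network Analyzer on the critic's layerGraph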

%%
criticOptions = rlRepresentationOptions('LearnRate',1e-03,'GradientThreshold',1);
% criticOptions.UseDevice = 'gpu';
critic = rlRepresentation(criticNetwork,obsInfo,actInfo,...
    'Observation',{'pendImage','angularRate'},'Action',{'action'},criticOptions);
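The commented-out UseDevice line above moves the critic's computations to the GPU. A hedged guard, assuming Parallel Computing Toolbox is installed, placed right before the rlRepresentation call:

if gpuDeviceCount > 0                  % requires Parallel Computing Toolbox
    criticOptions.UseDevice = 'gpu';   % otherwise keep the default 'cpu'
end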

%%
% Actor network: the same image path structure, but it ends in tanh + scaling
% so the output action stays within the torque limits.
imgPath = [
    imageInputLayer(obsInfo(1).Dimension,'Normalization','none','Name',obsInfo(1).Name)
    convolution2dLayer(10,2,'Name','conv1','Stride',5,'Padding',0)
    reluLayer('Name','relu1')
    fullyConnectedLayer(2,'Name','fc1')
    concatenationLayer(3,2,'Name','cat1')
    fullyConnectedLayer(hiddenLayerSize1,'Name','fc2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(hiddenLayerSize2,'Name','fc3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(1,'Name','fc4')
    tanhLayer('Name','tanh1')
    scalingLayer('Name','scale1','Scale',max(actInfo.UpperLimit))
    ];
% Actor network, angular-rate path
dthetaPath = [
    imageInputLayer(obsInfo(2).Dimension,'Normalization','none','Name',obsInfo(2).Name)
    fullyConnectedLayer(1,'Name','fc5','BiasLearnRateFactor',0,'Bias',0)
    ];

%%
actorNetwork = layerGraph(imgPath);
actorNetwork = addLayers(actorNetwork,dthetaPath);
actorNetwork = connectLayers(actorNetwork,'fc5','cat1/in2');

%%
actorOptions = rlRepresentationOptions('LearnRate',1e-04,'GradientThreshold',1);
% actorOptions.UseDevice = 'gpu';
actor = rlRepresentation(actorNetwork,obsInfo,actInfo,...
    'Observation',{'pendImage','angularRate'},'Action',{'scale1'},actorOptions);
figure
plot(actorNetwork)

%%
agentOptions = rlDDPGAgentOptions(...
    'SampleTime',env.Ts,...
    'TargetSmoothFactor',1e-3,...
    'ExperienceBufferLength',1e6,...
    'DiscountFactor',0.99,...
    'MiniBatchSize',128);
% Exploration noise: start with a fairly large variance and let it decay slowly
agentOptions.NoiseOptions.Variance = 0.6;
agentOptions.NoiseOptions.VarianceDecayRate = 1e-6;
agent = rlDDPGAgent(actor,critic,agentOptions);
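The two NoiseOptions lines control exploration. If, as documented for the toolbox's Ornstein-Uhlenbeck noise model, the variance is multiplied by (1 - VarianceDecayRate) at each agent step, a rough calculation (not in the original post) shows how slowly a rate of 1e-6 decays:

decay    = agentOptions.NoiseOptions.VarianceDecayRate;  % 1e-6
halfLife = log(0.5)/log(1 - decay);                      % roughly 6.9e5 steps
fprintf('Exploration noise variance halves after about %.0f steps\n',halfLife)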

%% Training options
maxepisodes = 5000;
maxsteps = 400;
trainingOptions = rlTrainingOptions(...
    'MaxEpisodes',maxepisodes,...
    'MaxStepsPerEpisode',maxsteps,...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',-740);
plot(env) % open the environment visualization
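Image-based training over thousands of episodes is slow, so it can be worth keeping intermediate agents as well. rlTrainingOptions supports this through its SaveAgent* properties; the threshold and folder name below are arbitrary examples, not from the original post:

trainingOptions.SaveAgentCriteria  = "EpisodeReward";
trainingOptions.SaveAgentValue     = -740;            % arbitrary example threshold
trainingOptions.SaveAgentDirectory = "savedAgents";   % hypothetical folder name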

%% Parallel training settings
trainingOptions.UseParallel = true;
trainingOptions.ParallelizationOptions.Mode = "async";
trainingOptions.ParallelizationOptions.DataToSendFromWorkers = "Experiences";
% -1 means each worker sends its experiences at the end of every episode
trainingOptions.ParallelizationOptions.StepsUntilDataIsSent = -1;
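With UseParallel enabled, train opens a parallel pool automatically, but starting one explicitly makes the worker count predictable. A sketch, assuming Parallel Computing Toolbox is installed:

if isempty(gcp('nocreate'))   % no pool open yet
    parpool(4);               % 4 workers is an arbitrary example count
end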

%% Train
trainingStats = train(agent,env,trainingOptions);
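Training can take a long time, so it is worth saving the result before experimenting further; the file name below is arbitrary:

save('imagePendulumDDPGAgent.mat','agent')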

%% Simulate the trained agent
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);
% bdclose(mdl) % commented out - mdl is not defined in this image-based script
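The sim output logs the signals of the episode; the rewards come back as a timeseries, so the episode return can be read out as below (a sketch; the exact field layout may differ between releases):

totalReward = sum(experience.Reward.Data)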
