
Training a reinforcement learning model in MATLAB with OpenAI Gym environments

Author: 万木逢春
Published 2020-07-31 17:28:58
Part of the column: 帮你学MatLab

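OpenAI Gym renders its environments with pyglet under the hood, which makes them awkward to customize, but the ready-made environments are still very usable. With the Python environment set up as described in the previous post, all that remains is to install gym: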

pip install gym
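Before going further it can help to confirm that MATLAB sees the same Python installation that gym was installed into. The snippet below is a minimal sketch, assuming MATLAB R2019b or later (where `pyenv` is available); the variable name `gymAvailable` is only illustrative.

```matlab
% Show which Python installation MATLAB is configured to use
pe = pyenv;
fprintf('Python version seen by MATLAB: %s\n', pe.Version);

% Try to import gym through MATLAB's Python interface
try
    py.importlib.import_module('gym');
    gymAvailable = true;
catch err
    gymAvailable = false;
    fprintf('gym could not be imported: %s\n', err.message);
end
```

If `gymAvailable` comes back false, the usual cause is that `pip` installed gym into a different Python than the one `pyenv` reports.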

The base install makes three families of environments available:

algorithmic

toy_text

classic_control

We are interested in classic_control: these environments simulate physical systems, so the physics does not have to be re-modeled in MATLAB.

Here we train in gym's MountainCar environment.
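To give a sense of what gym saves us from writing, here is a rough sketch of a single MountainCar time step in plain MATLAB. The constants and update rule are quoted from memory of gym's classic_control source, so treat them as assumptions and check them against your gym version; the variable names are illustrative.

```matlab
% Constants as recalled from gym's MountainCar-v0 (assumption: verify locally)
force = 0.001;
gravity = 0.0025;
maxSpeed = 0.07;
minPosition = -1.2;
maxPosition = 0.6;
goalPosition = 0.5;

% Example state and action (gym numbering: 0 = push left, 1 = no push, 2 = push right)
position = -0.5;
velocity = 0;
action = 2;

% One step: engine force plus the slope's pull, then clipping
velocity = velocity + (action - 1)*force - gravity*cos(3*position);
velocity = min(max(velocity, -maxSpeed), maxSpeed);
position = position + velocity;
position = min(max(position, minPosition), maxPosition);
if position == minPosition && velocity < 0
    velocity = 0;   % the car stops when it hits the left wall
end
done = position >= goalPosition;
```

The engine force is smaller than the pull of the slope, which is why the car has to rock back and forth to build momentum before it can reach the goal.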

First, create the environment:

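classdef MountainCarEnv < rl.env.MATLABEnvironment
    %MountainCarEnv: MATLAB wrapper around gym's MountainCar environment.
    %% Properties
    properties
        % Render the gym window on every environment update
        show = true;
        % gym environment object (a Python object)
        p
        % Current state: [position velocity]
        State
    end
    properties(Access = protected)
        % Episode-termination flag
        IsDone = false
    end
    %% Required methods
    methods
        % Constructor
        function this = MountainCarEnv()
            % Observation: 1-by-2 vector [position velocity]
            ObservationInfo = rlNumericSpec([1 2]);
            % Actions 1..3 (mapped to gym's 0..2 in step)
            ActionInfo = rlFiniteSetSpec(1:3);
            % Call the superclass constructor
            this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);
            % Create the gym environment and start from a freshly reset state
            this.p = py.gym.make('MountainCar-v0');
            this.State = double(this.p.reset());
            notifyEnvUpdated(this);
        end
        % Apply one action and return its effect
        function [Observation,Reward,IsDone,LoggedSignals] = step(this,action)
            LoggedSignals = [];
            % MATLAB actions 1..3 map to gym actions 0..2
            act = py.int(action-1);
            % gym returns the tuple (observation, reward, done, info)
            temp = cell(this.p.step(act));
            Observation = double(temp{1,1});
            IsDone = temp{1,3};
            % Shaped reward (gym's own reward, temp{1,2}, is not used):
            % it grows as the car moves right, with a large bonus at the goal
            Reward = (1+Observation(1))^2;
            if Observation(1) >= 0.5
                Reward = 1000;
            end
            this.State = Observation;
            this.IsDone = IsDone;
            notifyEnvUpdated(this);
        end
        % Reset the environment
        function InitialObservation = reset(this)
            % Use the start state gym actually sampled (position in
            % [-0.6, -0.4], velocity 0). Assumes the classic gym API
            % (before 0.26), where reset() returns only the observation.
            InitialObservation = double(this.p.reset());
            this.State = InitialObservation;
            this.IsDone = false;
            notifyEnvUpdated(this);
        end
    end
    %% Optional helper methods, added for convenience
    methods
        % Return whether the current episode has ended
        function isDone = is_done(this)
            isDone = this.IsDone;
        end
    end
    methods (Access = protected)
        % Called by notifyEnvUpdated: render the gym window if enabled
        function envUpdatedCallback(this)
            if this.show
                this.p.render();
            end
        end
    end
end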
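The `step` method above replaces gym's own reward (a constant -1 per step for MountainCar) with a shaped reward based on the car's position. The sketch below isolates that formula so it can be evaluated without gym; the handle name `shapedReward` and the sample positions are illustrative, not part of the original class.

```matlab
% Goal position used by the threshold in step
goalPosition = 0.5;

% Shaped reward: (1 + position)^2 below the goal, 1000 once the goal is reached
shapedReward = @(position) (position >= goalPosition) .* 1000 + ...
                           (position <  goalPosition) .* (1 + position).^2;

% Evaluate it across the track: left wall, valley floor, mid-slope, goal
positions = [-1.2 -0.5 0 0.5];
rewards = shapedReward(positions);
disp([positions; rewards]);
```

Because the reward is positive at every step and increases toward the right, the agent gets a learning signal long before it first reaches the goal, which the sparse -1 reward does not provide.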

Next, build the reinforcement learning network and agent:

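%% Load the environment
% The original called ccc here, which is not a MATLAB built-in and is
% presumably the author's own cleanup script; the line below stands in for it
clear; close all; clc
env = MountainCarEnv;
% Observation specification
obsInfo = getObservationInfo(env);
% Number of observation elements (2: position and velocity)
numObservations = obsInfo.Dimension(2);
% Action specification
actInfo = getActionInfo(env);
% Dimension of a single action (1: the action is a scalar)
numActions = actInfo.Dimension(1);
rng(0)
%% Build the agent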
statePath = [
    imageInputLayer([1 numObservations 1],'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24,'Name','CriticStateFC3')];
actionPath = [
    imageInputLayer([numActions 1 1],'Normalization','none','Name','action')
    fullyConnectedLayer(24,'Name','CriticActionFC1')];
commonPath = [
    additionLayer(2,'Name','add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','output')];
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork, actionPath);
criticNetwork = addLayers(criticNetwork, commonPath);    
criticNetwork = connectLayers(criticNetwork,'CriticStateFC3','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
% figure
% plot(criticNetwork)
criticOpts = rlRepresentationOptions('LearnRate',0.01,'GradientThreshold',1);
critic = rlRepresentation(criticNetwork,obsInfo,actInfo,'Observation',{'state'},'Action',{'action'},criticOpts);
agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...    
    'TargetUpdateMethod',"periodic", ...
    'TargetUpdateFrequency',4, ...   
    'ExperienceBufferLength',10000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',128);
agent = rlDQNAgent(critic,agentOpts);
%% Training options
trainOpts = rlTrainingOptions(...
    'MaxEpisodes', 500, ...
    'MaxStepsPerEpisode', 200, ...
    'Verbose', false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',1000); 
%% Train
% env.show=false;   % uncomment to skip rendering during training

trainingStats = train(agent,env,trainOpts);
%% Show the result
env.show=true;
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward);
Originally published: 2020-07-28

Shared from the WeChat public account 帮你学MatLab.
