前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >线程池管理的pipeline设计模式(用了“精进C++”里的内容)

线程池管理的pipeline设计模式(用了“精进C++”里的内容)

作者头像
用户9831583
发布2022-12-04 16:29:30
1.2K0
发布2022-12-04 16:29:30
举报
文章被收录于专栏:码出名企路

记录最近算法工程里开发的pipeline设计模式。优化了上一版本:

1,增加了线程池管理,每个node可以异步处理;

2,增加了callback,将最后一个node的结果callback到主程序,避免了参数传递的冗余实现;

3,去掉了模板类设计,避免只能在头文件中去实现的弊端;

4,去掉了“前一个node的输出就是后一个node的输入”的限制,避免了函数返回值带来的复制开销;

/** @ 带有线程池的pipeline:pipeline里的Node可以异步执行,加快处理速度 */

task_queue.h

/** @ 线程池的任务队列 @ 入队和出队 */

代码语言:cpp
复制
/**
 * Mutex-protected FIFO queue used to hand work items between pipeline stages.
 *
 * A queue can be chained to a downstream queue with connect(). Once chained,
 * enqueue() forwards every item to the downstream queue instead of storing it
 * locally, so a stage's output queue feeds the next stage's input queue.
 *
 * @tparam T  Item type; must be copy-constructible (enqueue) and
 *            move-assignable (dequeue / dequeue_wait).
 */
template<class T>
class TaskQueue
{
    public:
        TaskQueue() = default;
        ~TaskQueue() = default;

        /**
         * Appends a copy of @p t, or forwards it to the connected downstream
         * queue when one has been set with connect().
         *
         * @param t  Item to enqueue.
         */
        void enqueue(const T& t)
        {
            TaskQueue<T>* pNext = nullptr;
            {
                std::unique_lock<std::mutex> lock(m_mutex);
                pNext = m_pNextQueue;
                if(!pNext)
                {
                    m_queue.push(t);
                }
            }

            if(pNext)
            {
                // Forward after releasing our own lock so two queue mutexes
                // are never held at the same time.
                pNext->enqueue(t);
                return;
            }

            // Wake one consumer blocked in dequeue_wait().
            m_cond.notify_one();
        }

        /**
         * Non-blocking pop.
         *
         * @param t  Receives the front item on success; untouched otherwise.
         * @return   true if an item was removed, false if the queue was empty.
         */
        bool dequeue(T& t)
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            if(m_queue.empty())
                return false;

            t = std::move(m_queue.front());
            m_queue.pop();
            return true;
        }

        /** @return Number of items currently stored in this queue. */
        int32_t size() const
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            return static_cast<int32_t>(m_queue.size());
        }

        /** @return true when this queue currently stores no items. */
        bool empty() const
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            return m_queue.empty();
        }

        /**
         * Blocking pop with a timeout.
         *
         * @param t        Receives the front item on success; untouched otherwise.
         * @param timeout  Maximum time to wait for an item, in milliseconds.
         * @return         true if an item was removed, false if the queue was
         *                 still empty when the timeout expired.
         */
        bool dequeue_wait(T& t,uint32_t timeout)
        {
            std::unique_lock<std::mutex> lock(m_mutex);

            // The predicate overload re-checks the queue after every wake-up,
            // so a spurious wake-up cannot end the wait early.
            const bool hasItem = m_cond.wait_for(
                lock,
                std::chrono::milliseconds(timeout),
                [this](){ return !m_queue.empty(); });
            if(!hasItem)
                return false;

            t = std::move(m_queue.front());
            m_queue.pop();
            return true;
        }

        /**
         * Chains this queue to a downstream queue; later enqueue() calls are
         * forwarded to it. Passing nullptr removes the chaining.
         *
         * @param pQueue  Downstream queue (not owned; must outlive this queue's use).
         */
        void connect(TaskQueue<T>* pQueue)
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            m_pNextQueue = pQueue;
        }

    private:

        std::queue<T> m_queue;
        mutable std::mutex m_mutex;               // mutable so size()/empty() can be const
        std::condition_variable m_cond;           // signalled by enqueue() when an item is stored
        TaskQueue<T>* m_pNextQueue = nullptr;     // non-owning downstream queue, or nullptr
};

thread_manager.h

/** @ 线程管理 */

代码语言:cpp
复制
// Default number of worker threads created by ThreadManager.
static const uint32_t MaxThreadNums = 8;

/**
 * Fixed-size thread pool.
 *
 * Usage: construct, call init() to spawn the workers, submit callables with
 * postJobs(), and call shutdown() (or let the destructor do it) to stop.
 * The task queue and the shutdown flag are both guarded by m_mutex.
 */
class ThreadManager
{
    public:
        /**
         * @param thread_count  Number of worker threads init() will spawn;
         *                      values <= 0 create a pool with no workers.
         */
        ThreadManager(const int thread_count = MaxThreadNums)
            : m_threads(static_cast<std::vector<std::thread>::size_type>(thread_count > 0 ? thread_count : 0))
        {
        }

        /** Stops and joins all workers (see shutdown()). */
        ~ThreadManager(){
            this->shutdown();
        }

        ThreadManager(ThreadManager &&)=delete;
        ThreadManager(const ThreadManager &)=delete;
        ThreadManager &operator=(ThreadManager &&)=delete;
        ThreadManager &operator=(const ThreadManager &) =delete;

        /**
         * Spawns the worker threads. Slots that already hold a running worker
         * are left alone, so calling init() twice is harmless.
         */
        void init()
        {
            {
                std::lock_guard<std::mutex> lock(m_mutex);
                m_shutdown = false;
            }

            for(uint32_t i =0; i < m_threads.size();++i)
            {
                // Assigning to a joinable std::thread calls std::terminate.
                if(!m_threads.at(i).joinable())
                {
                    m_threads.at(i) = std::thread(ThreadWorker(this,i));
                }
            }
        }

        /**
         * Asks the workers to finish and joins them. Tasks already queued are
         * executed before the workers exit. Safe to call more than once.
         */
        void shutdown()
        {
            {
                // Set the flag under the mutex so a worker cannot miss it
                // between evaluating its wait predicate and going to sleep.
                std::lock_guard<std::mutex> lock(m_mutex);
                m_shutdown = true;
            }
            m_cond.notify_all();

            for(auto& worker : m_threads)
            {
                if(worker.joinable())
                {
                    worker.join();
                }
            }
        }

        /**
         * Queues f(args...) for execution on a worker thread.
         *
         * @param f     Callable to run.
         * @param args  Arguments bound to @p f (copied or moved into the task).
         * @return      Future that receives the call's result, or the exception
         *              it threw. The future only becomes ready if a worker runs
         *              the task, i.e. init() has been called and the pool has
         *              not already finished shutting down.
         */
        template<typename F,typename... Args>
        auto postJobs(F&& f, Args &&...args)->std::future<decltype(f(args...))>
        {
            using Result = decltype(f(args...));

            // packaged_task is move-only while std::function requires a
            // copyable target, hence the shared_ptr indirection.
            auto task_ptr = std::make_shared<std::packaged_task<Result()>>(
                std::bind(std::forward<F>(f),std::forward<Args>(args)...));

            // Take the future before the task is visible to the workers.
            std::future<Result> result = task_ptr->get_future();

            {
                std::lock_guard<std::mutex> lock(m_mutex);
                m_task_queue.push([task_ptr]()
                {
                    (*task_ptr)();
                });
            }
            m_cond.notify_one();

            return result;
        }

    private:

        /** Function object executed by each worker thread. */
        class ThreadWorker
        {
            public:
                ThreadWorker(ThreadManager *pThreadManager,const int32_t tid):m_tid(tid),m_pThreadManager(pThreadManager){

                };

                /** Pops and runs tasks until shutdown is requested and the queue is empty. */
                void operator()()
                {
                    for(;;)
                    {
                        std::function<void()> task;
                        {
                            std::unique_lock<std::mutex> lock(m_pThreadManager->m_mutex);
                            m_pThreadManager->m_cond.wait(lock,[this](){
                                return m_pThreadManager->m_shutdown
                                    || !m_pThreadManager->m_task_queue.empty();
                            });

                            // Woken with nothing queued: only possible on shutdown.
                            if(m_pThreadManager->m_task_queue.empty())
                                return;

                            task = std::move(m_pThreadManager->m_task_queue.front());
                            m_pThreadManager->m_task_queue.pop();
                        }

                        // Run outside the lock so other workers can dequeue.
                        task();
                    }
                }

            private:
                int32_t m_tid;                    // index of this worker in the pool
                ThreadManager *m_pThreadManager;  // owning pool (not owned here)
        };

    private:
        bool m_shutdown = false;                  // guarded by m_mutex
        std::mutex m_mutex;
        std::condition_variable m_cond;
        std::vector<std::thread> m_threads;
        std::queue<std::function<void()>> m_task_queue;  // guarded by m_mutex

};

common_struct.h

/** @ pipeline的入参结构体 */

代码语言:cpp
复制
/** Role a node plays inside a pipeline. */
enum NodeType
{
    Source,   // entry stage: requests are submitted to it
    Channel,  // presumably an intermediate stage (not used by the demo nodes)
    Sink      // final stage: its result is delivered through the result callback
};

/** Description of a single node: the name used to create it and its role. */
struct NodeNeedInfo
{
    std::string name;
    NodeType type = Source;
};

/** Request object that travels through the pipeline from node to node. */
struct InputRequestInfo
{
    bool isOK = false;
    uint32_t requestId = 0;

    // Per-node input information.
    NodeNeedInfo nodeInfo[8];
};
using NodeNeedInfoPtr = std::shared_ptr<NodeNeedInfo>;
using InputRequestInfoPtr = std::shared_ptr<InputRequestInfo>;

/** Callback that receives a request once the sink node has processed it. */
using ResultCallback = std::function<void(const InputRequestInfoPtr&)>;

/** Everything needed to build a pipeline. */
struct PipelineDescriptor
{
    uint32_t nums = 0;   // number of valid entries in `nodes`
    std::string name;

    // Node descriptions, in pipeline order; only the first `nums` are used.
    NodeNeedInfo nodes[8];
    ResultCallback callback;
};
using PipelineDescriptorPtr = std::shared_ptr<PipelineDescriptor>;

node.h

//node.h : base Node /*** @ 1, 去掉了类模板 @ 2, 不要求上一级的输出作为下一级的输入 @ 3, 通过callback的方式将最后一级(Sink)的结果回传给主程序 */

代码语言:cpp
复制
class Node
{
    public:
        Node(): m_stop(false),m_is_sink(false){};
        virtual ~Node() = default;

        virtual int32_t initialize(const std::string& conf) = 0;
        virtual int32_t process(InputRequestInfoPtr pRequestInfo) = 0;

        virtual  std::string getNodeName() const= 0;

        virtual NodeType Type()const =0;
    

    public:

        void start()
        {
            //起线程处理
            m_thread = std::thread([this](){
                executeRequest();
            });
        }

        void stop()
        {
            m_stop = true;
            if(m_thread.joinable())
            {
                m_thread.join();
            }
        }

        // inline std::string getNodeName() const
        // {
        //     return m_node_name;
        // }

        void executeRequest()
        {   
            int count = 0;
            while(!m_stop)
            {
                InputRequestInfoPtr pRequest;
                if(m_input_queue.dequeue(pRequest))
                {
                    int32_t ret = process(pRequest);

                    if(ret != 0)
                    {
                        ///////////
                    }

                    //set request for next node
                    if(m_type != NodeType::Sink)//bug to do
                    {
                        count++;
                        m_output_queue.enqueue(pRequest);
                    }
                    else
                    {
                        m_result_callback(pRequest);//回到main: publishResult
                    }
                 
                    
                }
                else
                {
                    ////////////
                }
            }
        }

        TaskQueue<InputRequestInfoPtr> &input_queue()
        {
            return m_input_queue;
        }

        TaskQueue<InputRequestInfoPtr> &output_queue()
        {
            return m_output_queue;
        }

        // inline NodeType Type()const
        // {
        //     return m_type;
        // }

        void callbackRegister(ResultCallback callback)
        {
            m_result_callback = std::move(callback);
        }

    private:
        bool m_stop;
        bool m_is_sink;
        bool m_source;
        TaskQueue<InputRequestInfoPtr> m_input_queue;
        TaskQueue<InputRequestInfoPtr> m_output_queue;
        std::thread  m_thread;
        std::string m_node_name;
        ResultCallback m_result_callback;
        NodeType m_type;

};

nodeA/B

/** NodeA -> NodeB */

代码语言:cpp
复制
class Node_A :public Node
{
    public:
        Node_A() = default;
        ~Node_A() =default;

        int32_t initialize(const std::string& conf)override{
            std::cout<<"I am NodeA initialize"<<std::endl;
            return 0;
        }

        int32_t process(InputRequestInfoPtr pRequestInfo)override{
            std::cout<<"I am NodeA process"<<std::endl;
            pRequestInfo->requestId = 100;
            return 0;
        }

        std::string getNodeName()const override
        {
            return "Node_A";
        }

        NodeType Type()const override
        {
            return NodeType::Source;
        }

};

//NodeB
class Node_B :public Node
{
    public:
        Node_B() = default;
        ~Node_B() =default;

        int32_t initialize(const std::string& conf)override{
            std::cout<<"I am NodeB initialize"<<std::endl;
            return 0;
        }

        int32_t process(InputRequestInfoPtr pRequestInfo)override{
            std::cout<<"I am NodeB process"<<std::endl;
            return 0;
        }

        std::string getNodeName()const override
        {
            return "Node_B";
        }

        NodeType Type()const override
        {
            return NodeType::Sink;
        }

};

perceptionPipeline.h

/** @ 一个具体的pipeline */

代码语言:cpp
复制
class PerceptionPipeline
{

    public:
        PerceptionPipeline()=default;
        ~PerceptionPipeline()=default;

        /**
        @ submit request to source node
        */
        void submit(InputRequestInfoPtr& pRequest)
        {
            m_pNodes[0]->input_queue().enqueue(pRequest);
        }

        /**
        @ initialize an pipline
        */
        int32_t initialize(const PipelineDescriptorPtr& pPipelineDesc)
        {
            int32_t result = 0;
            result = createNodes(pPipelineDesc);

            return result;
        }
        
        int32_t createNodes(const PipelineDescriptorPtr& pPipelineInfo)
        {
            int32_t result = 0;
            for(uint32_t i=0 ; i < pPipelineInfo->nums; i++)
            {
                //todo factory create nodes
                std::shared_ptr<Node> pNode = std::move(CreateNode(pPipelineInfo->nodes[i]));

                result = pNode->initialize("lxk");

                if(0!=result)
                {
                    //////////////
                    break;
                }

                if(pNode->Type() == NodeType::Sink)
                {   
                    std::cout<<"------------callbackRegister-----------"<<std::endl;
                    pNode->callbackRegister(pPipelineInfo->callback);
                }

                this->addNode(pNode);
            }

            return result;
        }
        static std::shared_ptr<Node> CreateNode(const NodeNeedInfo& node_desc)
        {
            if(node_desc.name == "NodeA")
                return (std::make_shared<Node_A>());
            if(node_desc.name == "NodeB")
                return (std::make_shared<Node_B>());

            return nullptr;
        }

        void start()
        {
            for(auto i:m_pNodes)
            {
                i->start();
            }
        }

        void stop()
        {
            for(auto i:m_pNodes)
            {
                i->stop();
            }
        }

        std::string PipelineInfo()
        {
            std::stringstream sstr;
            sstr<<"\n";
            sstr<<"-------Pipeline info  start----------\n";
            sstr<<"number of nodes: "<<m_pNodes.size()<<"\n";
            for(uint32_t i =0; i <m_pNodes.size(); i++)
            {
                if(i == m_pNodes.size() -1)
                {
                    sstr<<m_pNodes[i]->getNodeName()<<"\n";
                }
                else
                {
                    sstr<<m_pNodes[i]->getNodeName()<<"->";
                }
            }

            sstr<<"----------Pipeline info end----------\n";

            return sstr.str();
        }

    private:

        void addNode(std::shared_ptr<Node>& pNode)
        {
            std::shared_ptr<Node> pTail = nullptr;
            if(!m_pNodes.empty())
            {
                pTail = m_pNodes.back();
            }
            m_pNodes.push_back(pNode);

            //connect output queue node with input queue of next node
            if(pTail)
            {
                pTail->output_queue().connect(&pNode->input_queue());
            }
        }

    private:
        std::vector<std::shared_ptr<Node>> m_pNodes;
};

CameraPerception

/** @ 实际测试案例 */

代码语言:cpp
复制
// Demo application object: owns a ThreadManager and a PerceptionPipeline and
// wires the pipeline's result callback to publishResult().
class CameraPerception
{
    public:
        CameraPerception();
        ~CameraPerception();

        // Creates the thread pool and the pipeline, starts the pipeline and
        // submits one demo request.
        bool init();

    private:

        // Builds one demo request and submits it to the pipeline.
        void cameraPerceptionCallback();

        // Result callback bound into the pipeline descriptor by init(); posts
        // MarkObstacleOnImage() to the thread pool.
        void publishResult(const InputRequestInfoPtr& pInferResult);

        // Job executed through the thread pool; prints the request id.
        void MarkObstacleOnImage(uint64_t request_id);

        // Members are destroyed in reverse declaration order: the thread
        // manager first, then the pipeline.
        std::unique_ptr<PerceptionPipeline> m_perception_pipeline;
        std::unique_ptr<ThreadManager> m_thread_manager;

};

// Nothing to set up here: both owned pointers start out null and are created by init().
CameraPerception::CameraPerception() = default;

// Stops the pipeline's node threads before the members are destroyed.
CameraPerception::~CameraPerception()
{
    // The pipeline only exists after init(); destroying an object that was
    // never initialized must not dereference a null pointer.
    if(m_perception_pipeline)
    {
        m_perception_pipeline->stop();
    }
}

// Creates the thread pool and a two-node pipeline (NodeA -> NodeB), starts the
// pipeline and submits one demo request.
//
// Returns true on success, false when the pipeline could not be initialized.
bool CameraPerception::init()
{
    // On re-initialization, stop the old pipeline first: its nodes call
    // publishResult(), which uses the thread manager replaced below.
    if(m_perception_pipeline)
    {
        m_perception_pipeline->stop();
    }

    m_thread_manager.reset(new ThreadManager());
    m_thread_manager->init();

    m_perception_pipeline.reset(new PerceptionPipeline);

    PipelineDescriptorPtr pPipeline = std::make_shared<PipelineDescriptor>();
    pPipeline->name = "perception pipeline";

    const uint32_t count = 2;

    pPipeline->nodes[0].name = "NodeA";
    pPipeline->nodes[0].type = NodeType::Source;

    pPipeline->nodes[1].name = "NodeB";
    pPipeline->nodes[1].type = NodeType::Sink;
    pPipeline->nums = count;

    // The sink node delivers its result to publishResult().
    pPipeline->callback = std::bind(&CameraPerception::publishResult, this, std::placeholders::_1);
    const int32_t ret = m_perception_pipeline->initialize(pPipeline);

    if(ret != 0)
    {
        std::cout<<"pipeline init error";
        return false;
    }

    m_perception_pipeline->start();

    std::cout<<"pipeline info: "<<m_perception_pipeline->PipelineInfo()<<std::endl;

    cameraPerceptionCallback();

    // ret is 0 here; returning it would convert to false and report failure.
    return true;
}


void CameraPerception::cameraPerceptionCallback()
{
    InputRequestInfoPtr input_info(new InputRequestInfo);
    input_info->requestId = 1;
    input_info->isOK = true;

    for(size_t i =0; i < 3; i++)
    {
        input_info->nodeInfo[i].name = "lxkkk";
    }
    
    m_perception_pipeline->submit(input_info);

}

//callbacked by node
void CameraPerception::publishResult(const InputRequestInfoPtr& pInferResult)
{
    std::cout<<"publishResult:  ID: "<<pInferResult->requestId<<std::endl;

    m_thread_manager->postJobs(std::bind(&CameraPerception::MarkObstacleOnImage, this,pInferResult->requestId));
}

// Thread-pool job: prints the id of the request it was posted for.
void CameraPerception::MarkObstacleOnImage(uint64_t request_id)
{
    std::ostream& out = std::cout;
    out << "MarkObstacleOnImage:  ID: " << request_id << std::endl;
}
// Demo entry point: builds a CameraPerception, runs init() once and exits.
int main()
{
    const std::unique_ptr<CameraPerception> perception(new CameraPerception());
    perception->init();
    return 0;
}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-11-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 码出名企路 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • task_queue.h
  • thread_manager.h
  • common_struct.h
  • node.h
  • nodeA/B
  • perceptionPipeline.h
  • CameraPerception
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档