首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在std::map中插入dtype

在std::map中插入dtype
EN

Stack Overflow用户
提问于 2022-06-24 09:53:54
回答 1查看 52关注 0票数 0

我想要绘制一张包含一对pybind11::dtypeint的地图,并将其映射为OpenCV格式:

代码语言:javascript
复制
static std::map<std::pair<pybind11::dtype, int>, int> ocv_types;

因此,我insert编辑了所有的组合,但是在添加int32_tfloat_t时似乎存在问题。

代码语言:javascript
复制
    ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::int32_t>() , 3), CV_32SC3));

    ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::float_t>() , 3), CV_32FC3));

当我这样做时,只有 CV_32SC3才是真正的insert编辑,my猜测在某个地方程序“认为”两个元素是相等的,因此不会插入第二个元素。

我怎么才能把这2加起来呢?

P.S. I这样做只是为了“证明”类型不相等:

代码语言:javascript
复制
    if(pybind11::dtype::of<std::int32_t>() == pybind11::dtype::of<std::float_t>())
    {
        std::cout << "std::int32_t == std::float_t" << std::endl;
    }
    else
    {
        std::cout << "std::int32_t != std::float_t" << std::endl;
    }

..。当然他们不是。

编辑

我为<添加了dtype函数,并在地图的比较函数中使用了它,但并非所有元素都存在于地图中:

代码语言:javascript
复制
int getVal(pybind11::dtype type)
{
    if(type.is(pybind11::dtype::of<std::uint8_t>()))
        return 1;
    if(type.is(pybind11::dtype::of<std::uint16_t>()))
        return 2;
    if(type.is(pybind11::dtype::of<std::int16_t>()))
        return 3;
    if(type.is(pybind11::dtype::of<std::int32_t>()))
        return 4;
    if(type.is(pybind11::dtype::of<std::float_t>()))
        return 5;
    if(type.is(pybind11::dtype::of<std::double_t>()))
        return 6;
}

inline bool operator <(const pybind11::dtype a, const pybind11::dtype b) //friend claim has to be here
{
    return getVal(a) < getVal(b);
}

auto comp = [](const std::pair<pybind11::dtype, int> a, const std::pair<pybind11::dtype, int> b)
{
    return a < b;
};
static std::map<std::pair<pybind11::dtype, int>, int, decltype(comp)> ocv_types(comp);
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-24 14:57:43

正如您所指出的,pybind11::dtype没有任何特定的订单。因此,IMO的最佳方法是使用std::unordered_map并提供相应的散列。pybind11已经有了一些散列函数,因此需要在std::hash中采用它。

下面是我编写的测试(使用Catch2),它在我的机器上通过:

main.cpp:

代码语言:javascript
复制
#include "catch2/catch_all.hpp"
#include <pybind11/embed.h>
#include <pybind11/numpy.h>
#include <unordered_map>

template<>
struct std::hash<pybind11::dtype>
{
    size_t operator()(const pybind11::dtype &t) const
    {
        return pybind11::hash(t);
    }
};

template<>
struct std::hash<std::pair<pybind11::dtype, int>>
{
    size_t operator()(const std::pair<pybind11::dtype, int> &t) const
    {
        return std::hash<pybind11::dtype>{}(t.first) ^ static_cast<size_t>(t.second);
    }
};


TEST_CASE("map_with_dtype") {
    constexpr auto CV_32SC3 = 1;
    constexpr auto CV_32FC3 = 2;

    pybind11::scoped_interpreter guard{};

    std::unordered_map<std::pair<pybind11::dtype, int>, int> ocv_types;
    REQUIRE(ocv_types.empty());

    auto a = ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::int32_t>() , 3), CV_32SC3));
    REQUIRE(a.second);

    auto b = ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::float_t>() , 3), CV_32FC3));
    REQUIRE(b.second);
    CHECK(b.first->second == CV_32FC3);

    CHECK(ocv_types.size() == 2);
}

CMakeLists.txt:

代码语言:javascript
复制
cmake_minimum_required(VERSION 3.16)

# set the project name
project(MapOfPyBind11)

find_package(Catch2 REQUIRED)
find_package(pybind11 REQUIRED)

# add the executable
add_executable(MapOfPyBind11Test main.cpp)
target_link_libraries(MapOfPyBind11Test PRIVATE Catch2::Catch2 pybind11::module pybind11::embed)

include(CTest)
include(Catch)
catch_discover_tests(MapOfPyBind11Test)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72742207

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档