首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >ValueError:轴与数组Pytorch转置操作不匹配

ValueError:轴与数组Pytorch转置操作不匹配
EN

Stack Overflow用户
提问于 2020-05-25 02:03:26
回答 2查看 509关注 0票数 0

我正在尝试将PIL图像转换为torch变量类型。这是它的代码:

代码语言:javascript
运行
复制
def preprocess_image(pil_im, resize_im=True):
    """
        Processes image for CNNs
    Args:
        PIL_img (PIL_img): PIL Image or numpy array to process
        resize_im (bool): Resize to 224 or not
    returns:
        im_as_var (torch variable): Variable that contains processed float tensor
    """
    # mean and std list for channels (Imagenet)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    #ensure or transform incoming image to PIL image
    if type(pil_im) != Image.Image:
        try:
            pil_im = Image.fromarray(pil_im)
        except Exception as e:
            print("could not transform PIL_img to a PIL Image object. Please check input.")

    # Resize image
    if resize_im:
        pil_im = pil_im.resize((224, 224), Image.ANTIALIAS)

    im_as_arr = np.float32(pil_im)
    print(im_as_arr.shape)
    im_as_arr = im_as_arr.transpose(2, 0, 1)  # Convert array to D,W,H
    # Normalize the channels
    for channel, _ in enumerate(im_as_arr):
        im_as_arr[channel] /= 255
        im_as_arr[channel] -= mean[channel]
        im_as_arr[channel] /= std[channel]
    # Convert to float tensor
    im_as_ten = torch.from_numpy(im_as_arr).float()
    # Add one more channel to the beginning. Tensor shape = 1,3,224,224
    im_as_ten.unsqueeze_(0)
    # Convert to Pytorch variable
    im_as_var = Variable(im_as_ten, requires_grad=True)
    return im_as_var

original_image =  Image.open('blahblah.jpeg')
prep_img = preprocess_image(original_image)

我收到一个错误,它说

代码语言:javascript
运行
复制
ValueError                                Traceback (most recent call last)
<ipython-input-27-08cf62156870> in <module>()
      1 original_image =  Image.open('blahblah.jpeg')
----> 2 prep_img = preprocess_image(original_image)

<ipython-input-22-ad146391ce9d> in preprocess_image(pil_im, resize_im)
    155     im_as_arr = np.float32(pil_im)
    156     print(im_as_arr.shape)
--> 157     im_as_arr = im_as_arr.transpose(2, 0, 1)  # Convert array to D,W,H
    158     # Normalize the channels
    159     for channel, _ in enumerate(im_as_arr):

ValueError: axes don't match array

我认为转置操作有一些问题。但我在另一个用例上测试了它,这是完美的。不确定是什么触发了此错误。

EN

回答 2

Stack Overflow用户

发布于 2020-05-25 02:45:46

您的图像没有三维,这意味着它不是RGB图像,而是其他图像,例如灰度图像,从技术上讲,它只有一个通道,但NumPy会自动删除单个维度,因此不是具有大小H,W,1,而是具有大小H,W。

由于您似乎需要RGB图像,因此可以通过使用PIL转换图像来确保图像处于RGB模式。

代码语言:javascript
运行
复制
pil_img = pil_img.convert("RGB")
票数 0
EN

Stack Overflow用户

发布于 2020-05-25 02:56:31

pil_im = pil_im.resize((224, 224), Image.ANTIALIAS)改变了数组,使得它没有第三维(即灰度)。

im_as_arr = im_as_arr.transpose(2, 0, 1)假设图像是三维的(即RBG)。

在调用转置方法之前,您需要将图像转换为RBG并更改调整大小的尺寸;通过Michael Jungo建议的方法

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61990168

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档