首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >不一致的错误: TypeError: super( type,obj):obj必须是类型的实例或子类型

不一致的错误: TypeError: super( type,obj):obj必须是类型的实例或子类型
EN

Stack Overflow用户
提问于 2019-05-30 23:36:15
回答 1查看 1.9K关注 0票数 2

我有下面两个类的python代码。

代码语言:javascript
运行
AI代码解释
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

class QNet_baseline(nn.Module):

    """
        A MLP with 2 hidden layer

        observation_dim (int): number of observation features
        action_dim (int): Dimension of each action
        seed (int): Random seed
    """

    def __init__(self, observation_dim, action_dim, seed):
        super(QNet_baseline, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(observation_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, observations):
        """
           Forward propagation of neural network

        """

        x = F.relu(self.fc1(observations))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class QNet_3hidden(nn.Module):

    """
        A MLP with 3 hidden layer

        observation_dim (int): number of observation features
        action_dim (int): Dimension of each action
        seed (int): Random seed
    """

    def __init__(self, observation_dim, action_dim, seed):
        super(QNet_3hidden, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(observation_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, action_dim)

    def forward(self, observations):
        """
           Forward propagation of neural network

        """

        x = F.relu(self.fc1(observations))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

我使用相同的代码来实例化这两个类。QNet_baseline工作正常,但我得到了QNet_3hidden的以下错误。为什么QNet_baseline会工作,但QNet_3hidden有一个错误?我错过了什么?谢谢!

代码语言:javascript
运行
AI代码解释
复制
/home/workspace/QNetworks.py in __init__(self, observation_dim, action_dim, seed)
     44 
     45     def __init__(self, observation_dim, action_dim, seed):
---> 46         super(QNet_3hidden, self).__init__()
     47         self.seed = torch.manual_seed(seed)
     48         self.fc1 = nn.Linear(observation_dim, 128)

TypeError: super(type, obj): obj must be an instance or subtype of type

另外,下面是两个类如何实例化:

代码语言:javascript
运行
AI代码解释
复制
class DDQN_Agent():
    """Interacts with and learns from the environment.

    Attributes:
        state_size (int): dimension of each state
        action_size (int): dimension of each action
        seed (int): random seed
    """

    def __init__(self, state_size, action_size, seed, qnet="baseline", filename=None):
        """Initialize an Agent object.

        Args:
            filename: path of .pth file with trained weights
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        if qnet=="3hidden":
            self.qnetwork_local = QNet_3hidden(state_size, action_size, seed).to(device)
            self.qnetwork_target = QNet_3hidden(state_size, action_size, seed).to(device)
            self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
        else:
            self.qnetwork_local = QNet_baseline(state_size, action_size, seed).to(device)
            self.qnetwork_target = QNet_baseline(state_size, action_size, seed).to(device)
            self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        if filename:
            weights = torch.load(filename)
            self.qnetwork_local.load_state_dict(weights)
            self.qnetwork_target.load_state_dict(weights)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
EN

回答 1

Stack Overflow用户

发布于 2021-05-08 03:00:56

我遇到了一个类似的问题,完全重新启动内核起了作用。如本keitakurita的评论中所建议的

您是否正在运行木星笔记本中的代码,而没有重新启动内核?如果是这样的话,您的内核可能引用了错误的类。

我怀疑这可能是因为我在重写类后遇到了错误。

这也解释了为什么这是一个难以复制的错误。下面列出了一些类似的问题,以帮助跟踪相同的问题:

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56390705

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文