我有下面两个类的 Python 代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class QNet_baseline(nn.Module):
    """Multilayer perceptron with two hidden layers (128 and 64 units).

    Args:
        observation_dim (int): Number of observation features.
        action_dim (int): Dimension of each action; size of the output layer.
        seed (int): Random seed passed to ``torch.manual_seed``.
    """

    def __init__(self, observation_dim, action_dim, seed):
        # Zero-argument super() is bound to the class this method was defined
        # in. The explicit form super(QNet_baseline, self) looks the class up
        # by its global name at call time and raises
        # "TypeError: super(type, obj): obj must be an instance or subtype of
        # type" once that name has been rebound to a different class object
        # (e.g. after re-running a notebook cell or reloading the module).
        super().__init__()
        # Seeds PyTorch's global RNG before the layers are created, so the
        # initial weights are reproducible for a given seed.
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(observation_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, observations):
        """Forward propagation of the network.

        Args:
            observations: Input tensor whose last dimension is
                ``observation_dim``.

        Returns:
            Output of the final linear layer (no activation applied), with
            last dimension ``action_dim``.
        """
        hidden = F.relu(self.fc1(observations))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
class QNet_3hidden(nn.Module):
    """Multilayer perceptron with three hidden layers (128, 64 and 64 units).

    Args:
        observation_dim (int): Number of observation features.
        action_dim (int): Dimension of each action; size of the output layer.
        seed (int): Random seed passed to ``torch.manual_seed``.
    """

    def __init__(self, observation_dim, action_dim, seed):
        # Zero-argument super() is bound to the class this method was defined
        # in. The explicit form super(QNet_3hidden, self) looks the class up
        # by its global name at call time and raises
        # "TypeError: super(type, obj): obj must be an instance or subtype of
        # type" once that name has been rebound to a different class object
        # (e.g. after re-running a notebook cell or reloading the module).
        super().__init__()
        # Seeds PyTorch's global RNG before the layers are created, so the
        # initial weights are reproducible for a given seed.
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(observation_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, action_dim)

    def forward(self, observations):
        """Forward propagation of the network.

        Args:
            observations: Input tensor whose last dimension is
                ``observation_dim``.

        Returns:
            Output of the final linear layer (no activation applied), with
            last dimension ``action_dim``.
        """
        hidden = F.relu(self.fc1(observations))
        hidden = F.relu(self.fc2(hidden))
        hidden = F.relu(self.fc3(hidden))
        return self.fc4(hidden)
我使用相同的代码来实例化这两个类。QNet_baseline 工作正常,但实例化 QNet_3hidden 时却出现了下面的错误。
为什么 QNet_baseline 可以正常工作,而 QNet_3hidden 会报错?我遗漏了什么?谢谢!
/home/workspace/QNetworks.py in __init__(self, observation_dim, action_dim, seed)
44
45 def __init__(self, observation_dim, action_dim, seed):
---> 46 super(QNet_3hidden, self).__init__()
47 self.seed = torch.manual_seed(seed)
48 self.fc1 = nn.Linear(observation_dim, 128)
TypeError: super(type, obj): obj must be an instance or subtype of type
另外,下面是这两个类的实例化方式:
class DDQN_Agent:
    """Interacts with and learns from the environment.

    Attributes:
        state_size (int): dimension of each state
        action_size (int): dimension of each action
        seed (int): random seed
        qnetwork_local: network whose parameters the optimizer updates
        qnetwork_target: second network of the same architecture
        optimizer: Adam optimizer over ``qnetwork_local``'s parameters
        memory: replay buffer
        t_step (int): time-step counter (for updating every UPDATE_EVERY steps)
    """

    def __init__(self, state_size, action_size, seed, qnet="baseline", filename=None):
        """Initialize an Agent object.

        Args:
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
            qnet (str): "3hidden" selects QNet_3hidden; any other value
                selects QNet_baseline
            filename: path of .pth file with trained weights
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed the RNG first and store the
        # seed value itself; the original assignment always stored None.
        random.seed(seed)
        self.seed = seed

        # Q-Network: both networks share one architecture, so pick the class
        # once instead of duplicating the construction in each branch.
        network_cls = QNet_3hidden if qnet == "3hidden" else QNet_baseline
        self.qnetwork_local = network_cls(state_size, action_size, seed).to(device)
        self.qnetwork_target = network_cls(state_size, action_size, seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        if filename:
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # machine (and vice versa).
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            weights = torch.load(filename, map_location=device)
            self.qnetwork_local.load_state_dict(weights)
            self.qnetwork_target.load_state_dict(weights)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
发布于 2021-05-08 03:00:56
我遇到过类似的问题,完全重启内核后就解决了。正如 keitakurita 在评论中所建议的:
你是否在 Jupyter Notebook 中运行代码,而且没有重启内核?如果是这样,内核中引用的可能是旧的(错误的)类对象。
我认为这就是原因,因为我正是在重写类之后才遇到这个错误的:类被重新定义后,super(QNet_3hidden, self) 中的类名指向新的类对象,而 self 可能仍是旧类的实例。
这也解释了为什么这个错误难以复现。下面列出一个类似的问题,便于追踪同类错误:
https://stackoverflow.com/questions/56390705
复制相似问题