在PyTorch库中,可以定义具有某些初始值的神经网络参数nn.Parameter
,例如,
some_param = nn.Parameter(data=torch.rand(4,4))
在我的例子中,我想对这个参数强制执行一些结构。例如,考虑一个严格的下三角形式(在矩阵参数的情况下),因此some_param
的形式如下:
[ 0 0 0 0 ]
[a21 0 0 0 ]
[a31 a32 0 0 ]
[a41 a42 a43 0 ]
但是,如果我用
some_param = nn.Parameter(data=torch.tril(torch.rand(4,4),-1))
然后在训练中对角线上的零点可以得到非零.如何确保此参数在培训期间保持其结构?
发布于 2022-03-07 09:02:37
您可以显式地存储6个参数,然后动态构建完整的矩阵:
class StructuredParameter(nn.Module):
def __init__(self, ...):
self.explicit_p = nn.Parameter(torch.rand(6,))
self.tril_ind = torch.tril_indice(4, 4, -1)
def forward(self, ...):
# before using some_param - create it
some_param = torch.zeros(4, 4)
some_params[self.tril_ind] = self.explicit_p
# use some_param ...
https://stackoverflow.com/questions/71382161
复制