我想要使用functorch.jacrev计算BertForMaskedLM的雅克比矩阵,我的尝试方案:
import numpy as np
from transformers import BertTokenizer,BertForMaskedLM
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers, vmap, vjp, jvp, jacrev
device = 'cuda:2'
torch.cuda.empty_cache()
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertForMaskedLM.from_pretrained(model_name)
net = bert_model.to(device)
fnet, params, buffers = make_functional_with_buffers(net)
def fnet_single(params,x,y):
result = fnet(params, buffers, x.unsqueeze(0).unsqueeze(0),y.unsqueeze(0).unsqueeze(0))['logits']
return result.squeeze(0).squeeze(0)
text = u'江苏省苏州市读者马玉兰有一个在外地上学的朋友'
inputs = tokenizer.encode_plus(text)
segment_ids = inputs['token_type_ids']
token_ids = inputs['input_ids']
length = len(token_ids) - 2
batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
batch_segment_ids = torch.zeros_like(batch_token_ids).to(device)
for i in range(length):
if i > 0:
batch_token_ids[2 * i - 1, i] = 103
batch_token_ids[2 * i - 1, i + 1] = 103
batch_token_ids[2 * i, i + 1] = 103
threshold = 100
word_token_ids = [[token_ids[1]]]
for i in range(1, length):
x,y = batch_token_ids[2 * i],batch_segment_ids[2*i]
jacobian1 = jacrev(fnet_single,argnums=1)(params,x,y)
x,y = batch_token_ids[2 * i - 1],batch_segment_ids[2*i-1]
jacobian2 = jacrev(fnet_single,argnums=1)(params,x,y)
print(jacobian1,end='-----------------jacobian1-----------------\n')
print(jacobian2,end='-----------------jacobian2-----------------\n')
Traceback (most recent call last): File "study_jacrev.py", line 49, in batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device) RuntimeError: Only Tensors of floating point and complex dtype can require gradients
可以计算出BertForMaskedLM的雅克比矩阵(梯度)
期待您的答复!
相似问题