Recently, I ran into a problem with how to use the past_key_values returned by GPT-2. Here is a demo.

# last_hidden_states, h_s for short
# past_key_values, p_k_v for short
# A_h_s.shape = (bs, A_len, hs=768)
(_, A_p_k_v, A_h_s) = gpt2_model(input_ids=A_input_ids, token_type_ids=A_token_type_ids,
                                 attention_mask=A_attention_mask, position_ids=A_position_ids)
# B_h_s.shape = (bs, B_len, hs=768)
(_, B_p_k_v, B_h_s) = gpt2_model(input_ids=B_input_ids, token_type_ids=B_token_type_ids,
                                 attention_mask=B_attention_mask, position_ids=B_position_ids)
# Do some operations on A_h_s, such as integrating external knowledge
A_h_s = do_something(A_h_s)  # still (bs, A_len, hs=768)
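For reference, here is a minimal sketch of what past_key_values looks like in recent Hugging Face transformers versions when GPT-2 runs with use_cache=True (an assumption on my part; the dimensions below are the hypothetical GPT-2-small sizes, and the tensors are dummies so no model download is needed): a tuple with one entry per layer, each entry a (key, value) pair of shape (bs, n_heads, seq_len, head_dim).

```python
import torch

# Hypothetical GPT-2-small dimensions for illustration:
# 12 layers, 12 heads, head_dim 64 (12 * 64 = 768 = hs)
bs, n_layers, n_heads, head_dim = 2, 12, 12, 64
A_len = 5

# Dummy cache with the same nested structure GPT-2 returns:
# tuple of n_layers entries, each a (key, value) pair,
# each tensor of shape (bs, n_heads, seq_len, head_dim)
A_p_k_v = tuple(
    (torch.randn(bs, n_heads, A_len, head_dim),   # key cache for one layer
     torch.randn(bs, n_heads, A_len, head_dim))   # value cache for one layer
    for _ in range(n_layers)
)

print(len(A_p_k_v))         # 12  (one entry per layer)
print(A_p_k_v[0][0].shape)  # torch.Size([2, 12, 5, 64])
```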
The following part is the problem.
During training, I want to use A_h_s and B_h_s to predict C.
What am I supposed to do?
# attention_mask.shape = (bs, A_len + B_len + C_len)
attention_mask = torch.cat([A_attention_mask, B_attention_mask, C_attention_mask], dim=-1)
# cumsum over the full mask so C's positions continue after A and B
position_ids = torch.cumsum(attention_mask, dim=-1)[:, -C_len:].type_as(C_input_ids) - 1
past_key_values = ?????
# use lm_logits and C_label_ids to compute a cross-entropy loss for the backward pass
lm_logits, *_ = gpt2_model(input_ids=C_input_ids, token_type_ids=C_token_type_ids,
                           attention_mask=attention_mask, position_ids=position_ids,
                           past_key_values=past_key_values)
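The position_ids computation above can be checked on a toy example: cumsum the full concatenated mask, keep the last C_len entries, and subtract one, so that C's token positions continue from where A and B left off (here I assume a padding-free batch; with left padding the cumsum trick also skips padded positions).

```python
import torch

# Toy example: batch of 1, A_len=2, B_len=2, C_len=3, no padding
A_mask = torch.ones(1, 2, dtype=torch.long)
B_mask = torch.ones(1, 2, dtype=torch.long)
C_mask = torch.ones(1, 3, dtype=torch.long)
C_len = C_mask.shape[1]

attention_mask = torch.cat([A_mask, B_mask, C_mask], dim=-1)  # shape (1, 7)
# cumsum over the FULL mask, then keep only the last C_len entries,
# so C's positions start after A and B instead of at 0
position_ids = torch.cumsum(attention_mask, dim=-1)[:, -C_len:] - 1
print(position_ids)  # tensor([[4, 5, 6]])
```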

Maybe I can torch.cat([A_h_s, B_h_s], dim=1) and torch.cat([A_attention_mask, B_attention_mask], dim=-1), then feed them to GPT-2 to get the past_key_values. Am I right?
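An alternative worth considering (a sketch under my assumption of the Hugging Face GPT-2 cache layout, where the sequence axis of each cached key/value tensor is dim 2) is to merge the two already-computed caches layer by layer instead of re-running the model:

```python
import torch

def merge_past_key_values(A_p_k_v, B_p_k_v):
    """Concatenate two GPT-2 caches along the sequence axis (dim=2).

    Assumes the legacy Hugging Face tuple layout: one (key, value)
    pair per layer, each of shape (bs, n_heads, seq_len, head_dim).
    """
    return tuple(
        (torch.cat([a_k, b_k], dim=2),   # keys:   (bs, heads, A_len + B_len, hd)
         torch.cat([a_v, b_v], dim=2))   # values: (bs, heads, A_len + B_len, hd)
        for (a_k, a_v), (b_k, b_v) in zip(A_p_k_v, B_p_k_v)
    )

# Dummy caches for illustration: 2 layers, bs=1, 4 heads, head_dim=8
def make_cache(seq_len):
    return tuple(
        (torch.randn(1, 4, seq_len, 8), torch.randn(1, 4, seq_len, 8))
        for _ in range(2)
    )

merged = merge_past_key_values(make_cache(5), make_cache(3))
print(merged[0][0].shape)  # torch.Size([1, 4, 8, 8])  -> 5 + 3 = 8 cached positions
```

One caveat: the cached keys and values were computed from the original A_h_s, so any edits made by do_something(A_h_s) are not reflected in the merged cache; if C must attend to the modified representations, the modified states would need to be re-encoded rather than cached.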
