Recently I have run into a problem with how to use the past_key_values that GPT-2 returns. Here is the demo.
# last_hidden_states, h_s for short
# past_key_values, p_k_v for short
# A_h_s.shape = (bs, A_len, hs=768)
(_, A_p_k_v, A_h_s) = gpt2_model(input_ids=A_input_ids, token_type_ids=A_token_type_ids,
                                 attention_mask=A_attention_mask, position_ids=A_position_ids)
# B_h_s.shape = (bs, B_len, hs=768)
(_, B_p_k_v, B_h_s) = gpt2_model(input_ids=B_input_ids, token_type_ids=B_token_type_ids,
                                 attention_mask=B_attention_mask, position_ids=B_position_ids)
# Do some operations on A_h_s, such as integrating some external knowledge
A_h_s = do_something(A_h_s)  # (bs, A_len, hs=768)
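For reference, a minimal runnable sketch of what one of these calls could look like with the HuggingFace transformers API (the model name "gpt2" and the example text are illustrative assumptions, not part of my actual setup):

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

A = tokenizer("some text for segment A", return_tensors="pt")
out = model(input_ids=A["input_ids"], attention_mask=A["attention_mask"],
            use_cache=True, output_hidden_states=True, return_dict=True)
A_h_s = out.hidden_states[-1]    # (bs, A_len, 768), last layer's hidden states
A_p_k_v = out.past_key_values    # cached keys/values for every layer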
The following part is the problem.
During training, I want to use A_h_s and B_h_s as context to predict C.
What am I supposed to do?
# attention over A, B and C: (bs, A_len + B_len + C_len)
attention_mask = torch.cat([A_attention_mask, B_attention_mask, C_attention_mask], dim=-1)
# position ids for C continue after A and B (padding-aware via cumsum)
position_ids = torch.cumsum(attention_mask, dim=-1)[:, -C_len:].type_as(C_input_ids) - 1
past_key_values = ?????
# use lm_logits and C_label_ids to compute the cross-entropy loss for the backward pass
lm_logits, *_ = gpt2_model(input_ids=C_input_ids, token_type_ids=C_token_type_ids,
                           attention_mask=attention_mask, position_ids=position_ids,
                           past_key_values=past_key_values)
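For context, in the tuple-based transformers API, past_key_values for GPT-2 is one (key, value) pair per layer, each of shape (bs, n_head, past_len, head_dim) (newer versions may wrap this in a Cache object). For the call above, past_len would have to cover A_len + B_len. A dummy sketch of that structure, with purely illustrative sizes:

import torch

bs, A_len, B_len = 2, 5, 7                   # illustrative sizes, not from my data
n_layer, n_head, head_dim = 12, 12, 64       # GPT-2 small configuration
past_len = A_len + B_len
dummy_past_key_values = tuple(
    (torch.zeros(bs, n_head, past_len, head_dim),   # key placeholder for one layer
     torch.zeros(bs, n_head, past_len, head_dim))   # value placeholder for one layer
    for _ in range(n_layer)
)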
Maybe I can torch.cat([A_h_s, B_h_s], dim=1) and torch.cat([A_attention_mask, B_attention_mask], dim=-1), then feed the result back into GPT-2 to get the past_key_values, roughly as in the sketch below. Am I right?
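This is only a sketch of the idea, assuming the hidden states are concatenated along the sequence dimension and fed back in via inputs_embeds, which is exactly the part I am unsure about:

# (bs, A_len + B_len, 768)
AB_h_s = torch.cat([A_h_s, B_h_s], dim=1)
# (bs, A_len + B_len)
AB_attention_mask = torch.cat([A_attention_mask, B_attention_mask], dim=-1)
# re-run GPT-2 on the (possibly modified) hidden states to build a cache
out = gpt2_model(inputs_embeds=AB_h_s, attention_mask=AB_attention_mask,
                 use_cache=True, return_dict=True)
past_key_values = out.past_key_values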