Description
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 13.2.1
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.10
- Python version: 3.10
- Bazel version (if compiling from source): NA
- GPU model and memory: Apple Silicon M2 Max
- Exact command to reproduce:
Run the following script in a Jupyter notebook. When `@tf.function` is used, the outputs of `with_mask` and `without_mask` are the same, i.e. the mask is not applied.
```python
import tensorflow as tf


class ApplyMHA(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)

    @tf.function
    def call(self, x):
        return self.mha(query=x, key=x, value=x)


x = tf.convert_to_tensor(
    [[1, 2, 0, 0],
     [2, 3, 1, 0]]
)

# Create 2 embedding tables, one with masking, one without.
# They use the same seed as initializer, so the two tables should be
# identical except for masking.
initializer = tf.keras.initializers.RandomUniform(
    minval=-0.05, maxval=0.05, seed=123)
embedding_table_with_mask = tf.keras.layers.Embedding(
    input_dim=100, output_dim=3, mask_zero=True, embeddings_initializer=initializer)
embedding_table_without_mask = tf.keras.layers.Embedding(
    input_dim=100, output_dim=3, mask_zero=False, embeddings_initializer=initializer)

embedding_with_mask = embedding_table_with_mask(x)
print("embedding_with_mask:", embedding_with_mask)
embedding_without_mask = embedding_table_without_mask(x)
print("embedding_without_mask:", embedding_without_mask)

mha = ApplyMHA(num_heads=2, key_dim=3)
print("===After applying MHA====")
# The outputs of with_mask and without_mask are the same. Mask is not working.
print("with_mask:", mha(embedding_with_mask))
print("without_mask:", mha(embedding_without_mask))
```
Describe the problem.
In the script provided above, when MultiHeadAttention is called from a function annotated with `@tf.function`, it quietly ignores the masks attached to the input tensor. This is demonstrated by the "with_mask" and "without_mask" outputs being identical. Note that this problem only happens when the call comes from a context that uses `@tf.function`; removing `@tf.function` makes the issue disappear, as the comparison sketch below shows.
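For comparison, here is the same wrapper with only the `@tf.function` decorator removed; in that case the two printed outputs differ, i.e. the mask is applied. This snippet reuses the tensors from the repro script above.

```python
class ApplyMHAEager(tf.keras.layers.Layer):
    """Identical to ApplyMHA above, except call() is not wrapped in @tf.function."""

    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)

    def call(self, x):
        return self.mha(query=x, key=x, value=x)


mha_eager = ApplyMHAEager(num_heads=2, key_dim=3)
# Without @tf.function these two outputs differ, showing the mask is respected.
print("with_mask:", mha_eager(embedding_with_mask))
print("without_mask:", mha_eager(embedding_without_mask))
```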
IIUC, the root cause is that MultiHeadAttention (MHA) relies on the `_keras_mask` attribute attached to the input tensor (code), and `_keras_mask` is not available in the context of `@tf.function`. To fix the problem, MHA should instead rely on the `mask` argument passed to the `call` function, but this will be a pretty big change.
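Until such a change lands, one way to work around the issue is to compute the padding mask explicitly from the embedding layer and pass it through MultiHeadAttention's `attention_mask` argument, which does work under `@tf.function`. This is only a sketch built on top of the repro script above; the wrapper class name and the mask-broadcasting choice are mine, not part of Keras:

```python
class ApplyMHAWithExplicitMask(tf.keras.layers.Layer):
    """Workaround sketch: take the padding mask as an explicit argument
    instead of relying on the _keras_mask attribute of the input tensor."""

    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)

    @tf.function
    def call(self, x, padding_mask):
        # Build a [batch, query_len, key_len] boolean attention mask from the
        # [batch, seq_len] padding mask and hand it to MHA directly.
        attention_mask = tf.logical_and(
            padding_mask[:, tf.newaxis, :], padding_mask[:, :, tf.newaxis])
        return self.mha(query=x, key=x, value=x, attention_mask=attention_mask)


# The padding mask is obtained from the embedding layer rather than read off
# the embedded tensor, so it survives tracing inside @tf.function.
padding_mask = embedding_table_with_mask.compute_mask(x)
mha_explicit = ApplyMHAWithExplicitMask(num_heads=2, key_dim=3)
print("with explicit mask:", mha_explicit(embedding_with_mask, padding_mask))
```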
Describe the current behavior.
When MultiHeadAttention is called from a function annotated with `@tf.function`, it quietly ignores the masks attached to the input tensor.
Describe the expected behavior.
MultiHeadAttention should respect the masking from its inputs when called from a function annotated with `@tf.function`.
- Do you want to contribute a PR? (yes/no):
Yes. Change the call signature of MHA to accept an array of inputs together with their masks, instead of relying on `tensor._keras_mask`. A rough sketch of the idea is below.
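A very rough sketch of what that signature change could look like; the argument names here are only illustrative, not the actual Keras API:

```python
# Hypothetical sketch of the proposed MHA call signature (illustrative only):
def call(self, query, value, key=None, attention_mask=None,
         query_mask=None, value_mask=None, key_mask=None,
         return_attention_scores=False, training=None):
    # The explicit *_mask arguments would be combined with attention_mask
    # before computing attention, instead of reading the _keras_mask
    # attribute from the input tensors, so masking also works under
    # @tf.function.
    ...
```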
Standalone code to reproduce the issue.
Mentioned in the "Exact command to reproduce" section above. Here is a colab notebook that demonstrates the issue.
Source code / logs.
Mentioned in the "Exact command to reproduce" section above.