MultiHeadAttention quietly ignores masking when annotated with tf.function #56
Closed
@tsdeng

Description


System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 13.2.1
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.10
  • Python version: 3.10
  • Bazel version (if compiling from source): NA
  • GPU model and memory: Apple Silicon M2 Max
  • Exact command to reproduce:

Run the following script in a Jupyter notebook. The outputs of with_mask and without_mask are the same, i.e. the mask is not applied when @tf.function is used.

import tensorflow as tf

class ApplyMHA(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)

    @tf.function
    def call(self, x):
        return self.mha(query=x, key=x, value=x)

x = tf.convert_to_tensor(
    [[1, 2, 0, 0],
     [2, 3, 1, 0]]
)

# Create 2 embedding tables, one with masking, one without.
# They use the same seed for the initializer, so the two tables are identical except for masking.
initializer = tf.keras.initializers.RandomUniform(
    minval=-0.05, maxval=0.05, seed=123)

embedding_table_with_mask = tf.keras.layers.Embedding(input_dim=100, output_dim=3, mask_zero=True, embeddings_initializer=initializer)
embedding_table_without_mask = tf.keras.layers.Embedding(input_dim=100, output_dim=3, mask_zero=False, embeddings_initializer=initializer)
embedding_with_mask = embedding_table_with_mask(x)
print("embedding_with_mask:", embedding_with_mask)
embedding_without_mask = embedding_table_without_mask(x)
print("embedding_without_mask:", embedding_without_mask)

mha = ApplyMHA(num_heads=2, key_dim=3)
print("===After applying MHA====")

# The outputs of with_mask and without_mask are the same. Mask is not working.
print("with_mask:", mha(embedding_with_mask))
print("without_mask:", mha(embedding_without_mask))

Describe the problem.

In the script provided above, when MultiHeadAttention is called from a function annotated with @tf.function, it quietly ignores the mask attached to the input tensor. This is demonstrated by the "with_mask" and "without_mask" outputs being identical.

Note that this problem only happens when calling from a context that uses @tf.function; removing @tf.function makes the issue disappear.

IIUC, the root cause is that MultiHeadAttention (MHA) relies on the _keras_mask attribute attached to the input tensor (code). _keras_mask is not available inside a @tf.function context. To fix the problem, MHA should rely on the mask argument passed to its call function, but this would be a pretty big change. A workaround is sketched below.
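As a workaround until such a change lands, the mask can be computed explicitly and passed to MultiHeadAttention via its attention_mask argument, which works under @tf.function because no _keras_mask lookup is needed. A minimal sketch, reusing x, embedding_with_mask, and embedding_table_with_mask from the script above; the class name ApplyMHAWithMask is illustrative only:

import tensorflow as tf

class ApplyMHAWithMask(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)

    @tf.function
    def call(self, x, mask=None):
        attention_mask = None
        if mask is not None:
            # Broadcast the (batch, seq_len) padding mask to the
            # (batch, target_len, source_len) shape expected by attention_mask.
            attention_mask = mask[:, :, tf.newaxis] & mask[:, tf.newaxis, :]
        return self.mha(query=x, key=x, value=x, attention_mask=attention_mask)

# Pass the embedding layer's padding mask explicitly instead of relying on _keras_mask.
padding_mask = embedding_table_with_mask.compute_mask(x)  # (batch, seq_len) bool
out = ApplyMHAWithMask(num_heads=2, key_dim=3)(embedding_with_mask, mask=padding_mask)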

Describe the current behavior.
When MultiHeadAttention is called from a function annotated with @tf.function, it quietly ignores the mask attached to the input tensors.

Describe the expected behavior.
MultiHeadAttention should respect the mask from its inputs when called from a function annotated with @tf.function.

Contributing.

  • Do you want to contribute a PR? (yes/no):
    Yes. Change the call signature of MHA to accept an array of inputs and their masks, rather than relying on tensor._keras_mask (see the sketch below).
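For illustration, a rough sketch of the direction this could take; the helper name and signature below are hypothetical, not the actual keras-team API:

import tensorflow as tf

def compute_attention_mask(query_mask=None, value_mask=None):
    """Combine (batch, T) and (batch, S) padding masks into a (batch, T, S) attention mask."""
    mask = None
    if query_mask is not None:
        mask = tf.cast(query_mask, tf.bool)[:, :, tf.newaxis]   # (batch, T, 1)
    if value_mask is not None:
        v = tf.cast(value_mask, tf.bool)[:, tf.newaxis, :]      # (batch, 1, S)
        mask = v if mask is None else mask & v
    return mask

MHA's call would receive the query/value masks explicitly (or via Keras's standard mask argument) and build the attention mask with a helper like this, instead of reading _keras_mask off the tensors.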

Standalone code to reproduce the issue.

Mentioned in the Exact command to reproduce section. Here is a Colab notebook demonstrating the issue.

Source code / logs.

Mentioned in the Exact command to reproduce section.
