Demo "movielens-1m-keras-with-horovod.py" fails with "Exception: Optimizer type is not supported! got <class 'keras.src.optimizers.adam.Adam'>" · Issue #487 · tensorflow/recommenders-addons · GitHub


Open
kingofstorm opened this issue Mar 3, 2025 · 4 comments

@kingofstorm

System information

  • Docker container: tfra/dev_container:latest-tf2.15.1-python3.9 (CUDA 12.3, cuDNN 8.9)
  • TensorFlow 2.16.2, Keras 3.8
  • TensorFlow-Recommenders-Addons: 0.8.0
  • Python version: 3.9.7
  • Is GPU used? Yes (NVIDIA H20)

The packages above were installed with "pip install tensorflow==2.16.2 tensorflow-recommenders-addons==0.8.0", and Horovod was then reinstalled with "HOROVOD_WITH_TENSORFLOW=1 pip install --no-cache-dir horovod".

Describe the bug

[1,0]:Traceback (most recent call last):
[1,0]: File "/mnt/mfs/mfs6/cvr_tf_fm/code_complie/recommenders-addons-0.8.0/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py", line 782, in <module>
[1,0]: app.run(main)
[1,0]: File "/usr/local/lib/python3.9/site-packages/absl/app.py", line 308, in run
[1,0]: _run_main(main, args)
[1,0]: File "/usr/local/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
[1,0]: sys.exit(main(argv))
[1,0]: File "/mnt/mfs/mfs6/cvr_tf_fm/code_complie/recommenders-addons-0.8.0/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py", line 770, in main
[1,0]: train()
[1,0]: File "/mnt/mfs/mfs6/cvr_tf_fm/code_complie/recommenders-addons-0.8.0/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py", line 631, in train
[1,0]: optimizer = de.DynamicEmbeddingOptimizer(optimizer, synchronous=True)
[1,0]: File "/usr/local/lib/python3.9/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 859, in DynamicEmbeddingOptimizer
[1,0]: raise Exception(f"Optimizer type is not supported! got {str(type(self))}")
[1,0]:Exception: Optimizer type is not supported! got <class 'keras.src.optimizers.adam.Adam'>

Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.

Code to reproduce the issue
Run the demo "https://github.com/tensorflow/recommenders-addons/blob/master/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py" without any changes.

Other info / logs
Running the demo with bash -x start.sh produces the following trace:

+ rm -rf ./export_dir
++ nvidia-smi --query-gpu=name --format=csv,noheader
++ wc -l
+ gpu_num=8
+ export gpu_num
+ horovodrun -np 8 python movielens-1m-keras-with-horovod.py --mode=train --model_dir=./model_dir --export_dir=./export_dir --steps_per_epoch=20000 --shuffle=True

I also ran "python movielens-1m-keras-with-horovod.py --mode=train --model_dir=./model_dir --export_dir=./export_dir --steps_per_epoch=20000 --shuffle=True" directly, without horovodrun, and got the same error.


@PhyCoe
PhyCoe commented Apr 3, 2025

In the movielens-1m-keras folder under demo/dynamic_embedding, python movielens-1m-keras.py --mode=train --epochs=1 reports the same error.

OS Env: centos 6.10
tensorflow : 2.16.2
tensorflow-recommenders-addons : 0.8.1

2025-04-03 17:50:15.477693: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-03 17:50:15.518506: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-03 17:50:15.518576: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-03 17:50:15.546503: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-04-03 17:50:16.561519: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
WARNING:tensorflow:dynamic_embedding.GraphKeys has already been deprecated. The Variable will not be added to collections because it does not actully own any value, but only a holder of tables, which may lead to import_meta_graph failed since non-valued object has been added to collection. If you need to use tf.compat.v1.train.Saver and access all Variables from collection, you could manually add it to the collection by tf.compat.v1.add_to_collections(names, var) instead.
_SingleDeviceSaver removed after tf version 2.15
WARNING:tensorflow:An exception occurred when import horovod.tensorflow: No module named 'horovod'
WARNING:tensorflow:An exception occurred when import horovod.tensorflow: No module named 'horovod'
I0403 17:50:19.371832 140041422329600 dataset_info.py:690] Load dataset info from /home/jms/jupyter_projects/25H1-panhangyi/dataset/movielens/1m-ratings/0.1.1
I0403 17:50:19.378629 140041422329600 reader.py:261] Creating a tf.data.Dataset reading 4 files located in folders: /home/jms/jupyter_projects/25H1-panhangyi/dataset/movielens/1m-ratings/0.1.1.
I0403 17:50:19.475346 140041422329600 logging_logger.py:49] Constructing tf.data.Dataset movielens for split train, from /home/jms/jupyter_projects/25H1-panhangyi/dataset/movielens/1m-ratings/0.1.1
2025-04-03 17:50:19.580683: I ./tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h:157] HashTable on CPU is created on optimized mode: K=l, V=f, DIM=32, init_size=8192
2025-04-03 17:50:19.597778: I ./tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h:157] HashTable on CPU is created on optimized mode: K=l, V=f, DIM=32, init_size=8192
Traceback (most recent call last):
File "/home/jms/recommenders-addons/demo/dynamic_embedding/movielens-1m-keras/movielens-1m-keras.py", line 208, in <module>
app.run(main)
File "/root/anaconda3/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/root/anaconda3/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/jms/recommenders-addons/demo/dynamic_embedding/movielens-1m-keras/movielens-1m-keras.py", line 198, in main
train()
File "/home/jms/recommenders-addons/demo/dynamic_embedding/movielens-1m-keras/movielens-1m-keras.py", line 127, in train
optimizer = de.DynamicEmbeddingOptimizer(optimizer)
File "/root/anaconda3/lib/python3.10/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py", line 861, in DynamicEmbeddingOptimizer
raise Exception(f"Optimizer type is not supported! got {str(type(self))}")
Exception: Optimizer type is not supported! got <class 'keras.src.optimizers.adam.Adam'>

@PhyCoe
PhyCoe commented Apr 3, 2025

I found that /root/anaconda3/lib/python3.10/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py checks the optimizer against the Keras 2.x legacy and Keras 2.x optimizer interfaces:
from tf_keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tf_keras.optimizers import Optimizer as keras_OptimizerV2

However, the optimizer imported in the demo, <keras.src.optimizers.adam.Adam>, is a Keras 3.x optimizer. How should this be resolved?
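One way to confirm which Keras package an optimizer object actually comes from is to inspect its class. The sketch below is only a diagnostic aid: the helper name `optimizer_origin` and the stand-in class are hypothetical and are not part of TFRA or Keras. Note that the top-level package name alone does not give the major version, since some standalone Keras 2.x releases also use the `keras.src` module path, so the helper also reports the installed distribution version when it can find one.

```python
from importlib import metadata


def optimizer_origin(opt):
    """Return (top_level_package, qualified_class_name, installed_version_or_None).

    Hypothetical helper for debugging; it only inspects the object's class.
    """
    cls = type(opt)
    top = cls.__module__.split(".")[0]
    # Map the import name to the pip distribution name where they differ.
    dist = {"tf_keras": "tf-keras", "keras": "keras"}.get(top, top)
    try:
        version = metadata.version(dist)
    except metadata.PackageNotFoundError:
        version = None
    return top, f"{cls.__module__}.{cls.__qualname__}", version


# Stand-in class that mimics the class path from the error message,
# so the example runs without TensorFlow or Keras installed.
FakeAdam = type("Adam", (), {"__module__": "keras.src.optimizers.adam"})

top, name, version = optimizer_origin(FakeAdam())
print(top, name, version)
```

With a real optimizer, a top-level package of `keras` together with a 3.x version would match the situation described above, while `tf_keras` would indicate the Keras 2 interface that the TFRA check imports.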

@PhyCoe
PhyCoe commented Apr 7, 2025

I switched to tf 2.15.1 and tfra 0.7.0, and the demo now runs normally.

@linzb-xyz

TensorFlow 2.16.1 uses Keras 3 by default. You can try using Keras 2 by following the steps below.

To continue using Keras 2.0, do the following.

Install tf-keras via pip install tf-keras~=2.16

To switch tf.keras to use Keras 2 (tf-keras), set the environment variable TF_USE_LEGACY_KERAS=1 directly or in your python program with import os;os.environ["TF_USE_LEGACY_KERAS"]="1". Please note that this will set it for all packages in your Python runtime program

Change the keras import: replace import tensorflow.keras as keras or import keras with import tf_keras as keras. Update any tf.keras references to keras.
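The steps above can be sketched in Python as follows. This is a minimal sketch, assuming `tf-keras` is installed and that a `tf_keras` optimizer passes the type check in `DynamicEmbeddingOptimizer`; the thread does not confirm that this resolves the error on TensorFlow 2.16.2. The function name `build_de_optimizer` is hypothetical, and the imports are deferred into the function so that the environment variable is already set when TensorFlow is first loaded.

```python
import os

# Must be set before TensorFlow is imported for the first time in this process.
os.environ["TF_USE_LEGACY_KERAS"] = "1"


def build_de_optimizer(learning_rate=1e-3):
    """Build a Keras 2 (tf-keras) Adam optimizer and wrap it for dynamic embeddings.

    Hypothetical helper; requires `pip install tf-keras~=2.16` and
    tensorflow-recommenders-addons to be installed.
    """
    import tf_keras as keras  # Keras 2 package, instead of `import keras`
    from tensorflow_recommenders_addons import dynamic_embedding as de

    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    # Same wrapping call as in the demo's train() function.
    return de.DynamicEmbeddingOptimizer(optimizer)
```

In the demo scripts, the equivalent change would be to set the environment variable at the very top of the file and to replace the existing Keras imports with `tf_keras`.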
