33import inspect
44import os
55import sys
6+ import warnings
67
78from keras .src import backend as backend_module
89from keras .src .api_export import keras_export
@@ -124,14 +125,22 @@ def set_backend(backend):
124125
125126 Example:
126127
127- ```python
128- import keras
129-
130- keras.config.set_backend("jax")
131-
132- del keras
133- import keras
134- ```
128+ >>> import os
129+ >>> os.environ["KERAS_BACKEND"] = "tensorflow"
130+ >>>
131+ >>> import keras
132+ >>> from keras import ops
133+ >>> type(ops.ones(()))
134+ <class 'tensorflow.python.framework.ops.EagerTensor'>
135+ >>>
136+ >>> keras.config.set_backend("jax")
137+ UserWarning: Using `keras.config.set_backend` is dangerous...
138+ >>> del keras, ops
139+ >>>
140+ >>> import keras
141+ >>> from keras import ops
142+ >>> type(ops.ones(()))
143+ <class 'jaxlib.xla_extension.ArrayImpl'>
135144
136145 ⚠️ WARNING ⚠️: Using this function is dangerous and should be done
137146 carefully. Changing the backend will **NOT** convert
@@ -143,7 +152,7 @@ def set_backend(backend):
143152
144153 This includes any function or class instance that uses any Keras
145154 functionality. All such code needs to be re-executed after calling
146- `set_backend()` and re-importing the `keras` module .
155+ `set_backend()` and re-importing all imported `keras` modules .
147156 """
148157 os .environ ["KERAS_BACKEND" ] = backend
149158 # Clear module cache.
@@ -164,3 +173,16 @@ def set_backend(backend):
164173 module_name = module_name [module_name .find ("'" ) + 1 :]
165174 module_name = module_name [: module_name .find ("'" )]
166175 globals ()[key ] = importlib .import_module (module_name )
176+
177+ warnings .warn (
178+ "Using `keras.config.set_backend` is dangerous and should be done "
179+ "carefully. Already-instantiated objects will not be converted. Thus, "
180+ "any layers / tensors / etc. already created will no longer be usable "
181+ "without errors. It is strongly recommended not to keep around any "
182+ "Keras-originated objects instances created before calling "
183+ "`set_backend()`. This includes any function or class instance that "
184+ "uses any Keras functionality. All such code needs to be re-executed "
185+ "after calling `set_backend()` and re-importing all imported `keras` "
186+ "modules." ,
187+ stacklevel = 2 ,
188+ )
0 commit comments