Better support for initializing variables that depend on each other · Issue #4920 · tensorflow/tensorflow · GitHub


Closed

yaroslavvb opened this issue Oct 12, 2016 · 31 comments
Labels
stat:contribution welcome Status - Contributions welcome type:feature Feature requests

Comments

@yaroslavvb (Contributor) commented Oct 12, 2016

Currently there's no easy way to properly initialize variables when some variables' initial values depend on other variables and initialization has to be split over several .run calls. This kind of initialization comes up in data-dependent parameter init.

This could be solved if there were something like var.initialized_value(), but which only runs the initializer if the variable hasn't already been initialized.

Example:

import tensorflow as tf

a = tf.get_variable("a", shape=())
b = tf.get_variable("b", initializer=a.initialized_value())
c = tf.placeholder(tf.float32, shape=())
d = tf.get_variable("d", initializer=a.initialized_value()+c)

sess = tf.Session()
sess.run([a.initializer, b.initializer])
sess.run([d.initializer], feed_dict={c: 0})
sess.run([a, b, d])

Out[]: [0.30858743, -1.2756943, 0.30858743]

Here, a and b end up with different values because the initializer for a has been run twice. This is counter-intuitive: the user expects a, b, and d to have the same value.

@yaroslavvb (Contributor, Author)

BTW, here's a work-around using @purpledog's graph editor + tf.is_variable_initialized() to make initialize_all_variables work the same way whether or not initialized_value was used, and which works when initialization is split over several session runs. It works by adding a temporary dependency on the undocumented "variable caching node" <varname>/read. A more robust approach might be to add initializer_once to variables, which is like initializer but only takes effect when the variable is not initialized, and to add tf.initialize_all_variables_once that initializes all uninitialized variables in the correct order.
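
For concreteness, a minimal sketch of what such an initializer_once helper could look like; the name and exact API are hypothetical, and it mirrors the tf.cond + tf.is_variable_initialized pattern used by the work-around above:

import tensorflow as tf

def initializer_once(var):
    """Hypothetical helper: an op that runs var's initializer only if var has
    not been initialized yet, and is a no-op otherwise."""
    def do_init():
        # Re-run the initializer by assigning the original initial value.
        return tf.assign(var, var.initial_value).op
    def skip():
        return tf.no_op()
    return tf.cond(tf.is_variable_initialized(var), skip, do_init).op

A tf.initialize_all_variables_once could then group these per-variable ops; the remaining piece is making sure they run in dependency order, which the work-around above handles by temporarily adding control dependencies to each <varname>/read node.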

@drpngx added the enhancement and stat:contribution welcome (Status - Contributions welcome) labels Oct 24, 2016
@DjangoPeng (Contributor)

@yaroslavvb Would you mind me fixing the bug? I've been working on initialization these days. @drpngx

@yaroslavvb (Contributor, Author)

I'm not on the TensorFlow team, but since this is labeled "Contributions Welcome" it sounds like they'd be open to a fix. Maybe a way to start would be an initializer_once construct that checks whether the variable is initialized before running the initializer.

@DjangoPeng (Contributor) commented Oct 25, 2016

@yaroslavvb You said the initializer of a has been run twice; do you mean that b and d both call a.initialized_value()? Then d equals a from the second call, while b equals a from the first call, doesn't it?

@DjangoPeng (Contributor) commented Oct 25, 2016

Thanks for your advice on initializer_once.
Meanwhile, I found an interesting output for your example: using tf.initialize_all_variables(), it works well.

a = tf.get_variable("a", shape=())
b = tf.get_variable("b", initializer=a.initialized_value())
c = tf.placeholder(tf.float32, shape=())
d = tf.get_variable("d", initializer=a.initialized_value()+c)

sess = tf.Session()
sess.run(tf.initialize_all_variables(), feed_dict={c: 0})
sess.run([a, b, d])

Out[1]:  [1.1102008, 1.1102008, 1.1102008]

@yaroslavvb (Contributor, Author) commented Oct 25, 2016

Which version?
I tried 0.10 and 0.11rc1 with the same result:

import tensorflow as tf
a = tf.get_variable("a", shape=())
b = tf.get_variable("b", initializer=a.initialized_value())
c = tf.placeholder(tf.float32, shape=())
d = tf.get_variable("d", initializer=a.initialized_value()+c)
sess = tf.Session()
sess.run([a.initializer, b.initializer])
sess.run([d.initializer], feed_dict={c: 0})
sess.run([a, b, d])

Out[1]: [1.5315183, -0.65960622, 1.5315183]

@DjangoPeng (Contributor)

Did you try sess.run(tf.initialize_all_variables(), feed_dict={c: 0})?
If you initialize them with tf.initialize_all_variables() all at once, you'll get the right outputs.

@yaroslavvb (Contributor, Author) commented Oct 26, 2016

Correct, this issue only happens when initialization has to be split over several .run calls. Such a use case arises when you need to run some data through parts of your graph in order to determine initial values for variables in another part (i.e., data-dependent initialization).
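
To make the data-dependent case concrete, here is a minimal sketch (shapes and names are made up): the second variable's initial value is computed from statistics of real data flowing through the first variable, so its initializer can only run in a later .run call, after the first variable is initialized and a batch is fed.

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64])
w = tf.get_variable("w", [64, 32],
                    initializer=tf.truncated_normal_initializer(stddev=0.05))
pre_act = tf.matmul(x, w)

# scale's initial value depends on w (through pre_act) and on a data batch.
_, batch_var = tf.nn.moments(pre_act, [0])
scale = tf.get_variable("scale", initializer=1.0 / tf.sqrt(batch_var + 1e-10))

sess = tf.Session()
batch = np.random.randn(128, 64).astype(np.float32)
sess.run(w.initializer)                            # first .run call
sess.run(scale.initializer, feed_dict={x: batch})  # second .run call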

@DjangoPeng (Contributor)

But I think TF should support separate initialization for every variable. At the very least, it shouldn't give the wrong initialized_value, as @yaroslavvb described. What do you think, @drpngx? Should we open a PR to support that?

@drpngx (Contributor) commented Oct 28, 2016

Yes, I think it would help.

@eamartin

Are there any PRs associated with this?

I've written code several times that does something like

x = tf.placeholder(tf.float32, [100])
v0 = tf.get_variable('v0', initializer=tf.zeros_like(x))
z = x * v0
v1 = tf.get_variable('v1', initializer=tf.zeros_like(z))

This cannot be initialized with tf.initialize_all_variables() because v1's initializer depends on v0 (but only on its shape). I've hacked around this by writing a static_zeros_like function that only uses shape info available at graph-definition time (tensor.get_shape().as_list()).

It seems like the clean solution is to do a toposort on the graph consisting of all of the predecessors of any initializer. Executing the "initializer graph" nodes in this toposorted order will guarantee safe initialization.
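
A rough sketch of that idea, assuming TF 1.x graph mode and that each initializer's variable dependencies can be found by walking op inputs from its initial value (the helper names here are made up):

import tensorflow as tf

def init_in_dependency_order(sess):
    """Sketch: run each variable's initializer in topological order of the
    variables its initial value reads (data edges only)."""
    variables = tf.global_variables()
    var_by_op = {v.op: v for v in variables}

    def upstream_vars(var):
        # Variables reachable from var.initial_value by walking op inputs.
        deps, stack, seen = set(), [var.initial_value.op], set()
        while stack:
            op = stack.pop()
            if op in seen:
                continue
            seen.add(op)
            if op in var_by_op and var_by_op[op] is not var:
                deps.add(var_by_op[op])
            stack.extend(inp.op for inp in op.inputs)
        return deps

    remaining = {v: upstream_vars(v) for v in variables}
    initialized = set()
    while remaining:
        ready = [v for v, deps in remaining.items() if deps <= initialized]
        assert ready, "cyclic dependency between variable initializers"
        for v in ready:
            sess.run(v.initializer)  # one run call per variable, in a safe order
            initialized.add(v)
            del remaining[v]

Initializers that also depend on placeholders (as in the data-dependent case discussed above) would still need their feed_dict threaded into the corresponding run call.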

@ZhimingZhou

Hi @yaroslavvb,

I'm using smart_initialize, but I found that it re-does the initialization every time you call session run.

It seems the edge (or something like it) is not removed. I checked that the control_inputs are actually cleared, but there must be something else. I'm not very familiar with TensorFlow. Do you have any idea what is going on?

Thanks.

@yaroslavvb (Contributor, Author)

Can you provide a reproducible example? I'm assuming you are referring to smart_initialize from the gist: https://gist.github.com/yaroslavvb/d592394c0cedd32513f8fbb87ca05938

@ZhimingZhou commented Jan 18, 2017

Hi, here is an example:

# testing variable order init

import tensorflow as tf


def initialize_all_variables(sess=None):
    """Initializes all uninitialized variables in correct order. Initializers
    are only run for uninitialized variables, so it's safe to run this multiple
    times.
    Args:
        sess: session to use. Use default session if None.
    """

    from tensorflow.contrib import graph_editor as ge
    def make_initializer(var):
        def f():
            return tf.assign(var, var.initial_value).op

        return f

    def make_noop():
        return tf.no_op()

    def make_safe_initializer(var):
        """Returns initializer op that only runs for uninitialized ops."""
        return tf.cond(tf.is_variable_initialized(var), make_noop, make_initializer(var), name="safe_init_" + var.op.name).op

    if not sess:
        sess = tf.get_default_session()
    g = tf.get_default_graph()

    safe_initializers = {}
    for v in tf.all_variables():
        safe_initializers[v.op.name] = make_safe_initializer(v)

    # initializers access the variable value through the read-only value cached
    # in <varname>/read, so add a control dependency to trigger safe_initializer
    # on read access
    for v in tf.all_variables():
        var_name = v.op.name
        var_cache = g.get_operation_by_name(var_name + "/read")
        ge.reroute.add_control_inputs(var_cache, [safe_initializers[var_name]])

    sess.run(tf.group(*safe_initializers.values()))

    # remove initializer dependencies to avoid slowing down future variable reads
    for v in tf.all_variables():
        var_name = v.op.name
        var_cache = g.get_operation_by_name(var_name + "/read")
        ge.reroute.remove_control_inputs(var_cache, [safe_initializers[var_name]])


################################################################################
# Tests
################################################################################

def my_test():

    def linear_wn(input, output_size, stddev=0.05):

        def linear_wn_initializer_g(v, init_scale=1.0):
            v_norm = tf.nn.l2_normalize(v, [0])
            x_init = tf.matmul(input, v_norm)
            m_init, v_init = tf.nn.moments(x_init, [0])
            scale_init = init_scale / tf.sqrt(v_init + 1e-10)
            scale_init = tf.Print(scale_init, [0], 'linear ini running')
            return scale_init

        def linear_wn_initializer_b(v, init_scale=1.0):
            v_norm = tf.nn.l2_normalize(v, [0])
            x_init = tf.matmul(input, v_norm)
            m_init, v_init = tf.nn.moments(x_init, [0])
            scale_init = init_scale / tf.sqrt(v_init + 1e-10)
            return -m_init * scale_init

        if len(input.get_shape()) > 2:
            input = tf.reshape(input, [input.get_shape().as_list()[0], -1])

        v = tf.get_variable('v', [input.get_shape().as_list()[1], output_size], initializer=tf.truncated_normal_initializer(stddev=stddev), trainable=True)
        g = tf.get_variable('g', dtype=tf.float32, initializer=linear_wn_initializer_g(v.initialized_value()), trainable=True)
        b = tf.get_variable('b', dtype=tf.float32, initializer=linear_wn_initializer_b(v.initialized_value()), trainable=True)

        x = tf.matmul(input, v)
        scaler = g / tf.sqrt(tf.reduce_sum(tf.square(v), [0]))
        x = tf.reshape(scaler, [1, output_size]) * x + tf.reshape(b, [1, output_size])

        return x

    input = tf.get_variable('input', [1, 100], initializer=tf.truncated_normal_initializer(stddev=1.0), trainable=True)
    output = linear_wn(input, 10)

    sess = tf.InteractiveSession()
    initialize_all_variables(sess)

    opt = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(output)

    while True:
        sess.run([opt])

if __name__ == '__main__':
    my_test()

print("Tests passed")

It will output:

I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
I tensorflow/core/kernels/logging_ops.cc:79] linear ini running[0]
................

@yaroslavvb (Contributor, Author)

@ZhimingZhou thanks for the easy-to-reproduce test case. It seems the semantics of variables have changed. I used to add a dependency on the "variable/read" node to force the conditional initialization subgraph to run when "variable" is read, but now this no longer works. I tried adding a dependency on var.read_value(), but that creates new nodes on each evaluation.

@alextp is there some way to force an op to run when a variable is being read?

@alextp (Contributor) commented Jan 18, 2017 via email

@yaroslavvb (Contributor, Author)

@ZhimingZhou here's a fixed version using @alextp's suggestion -- http://pastebin.com/5zqcc5Q9

This needs a bit more thought though; I don't like the idea of users needing to keep separate tensors for variable write and variable read. Ideally one would be able to do a=tf.Variable(b+c) and have this work regardless of whether b and c are tensors or variables, with initialization order handled automatically.

@eamartin commented Jan 18, 2017

@yaroslavvb I think the nicest solution here is having a toposorted list of the variables, and then making a Session.run call for each variable's initializer. It might be possible to avoid the overhead of a bunch of Session.run calls by making a chain of control dependencies between initializers, but I doubt it, since it appears data dependencies aren't obeyed during initialization.

I was starting to write a toposort routine, but I realized the graph is constructed ancestors-first and the TF collections are appended to, which means tf.global_variables is already toposorted. This is a little bit of a hack since it's not specified that TF collections are appended to, but I think it will always work to do something like

for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    sess.run(v.initializer)

I'm running TensorFlow 0.10, but here's a non-trivial dependent initialization that works with this approach:

import numpy as np
import tensorflow as tf

x = tf.Variable(np.random.randn(5), name='x')
y = tf.Variable(x, name='y')
z = tf.Variable(tf.zeros_like(x) + y, name='z')
w = tf.Variable(2 * z, name='w')

with tf.Session() as sess:
    for v in tf.get_collection(tf.GraphKeys.VARIABLES):
        sess.run(v.initializer)

It seems that the larger problem (why these workarounds are needed) is that Session treats initializers (or maybe just running Ops instead of Tensors) differently and doesn't obey data or control dependencies.

@alextp (Contributor) commented Jan 18, 2017 via email

@yaroslavvb (Contributor, Author)

@alextp the problem with initialized_value is that it reruns initializers. This causes inconsistent results when your variable initialization has to be split over several session.run calls.

A different snag is this situation -- your code does calculations based on Tensor a. Later you decide to use a Variable for a instead of a Tensor. Now you have to go through your code and replace some instances of a with a.initialized_value().

@eamartin -- I think that's a reasonable replacement. First of all, the overhead of session.run is something like 100 usec, and you probably won't have more than 1000 layers in your topological sort. Secondly, I think you could collapse it into a single .run call by using a tf.while_loop.

@eamartin commented Jan 18, 2017

@alextp
That breaks the (useful) abstraction of not having to worry about whether a tensor is a variable or just a usual tensor.

Using initialized_value isn't very nice for cases like

x = tf.Variable(5.0)
y = x + 1.0
z = tf.Variable(tf.zeros_like(y))

Beyond this, you can imagine dependent initialization spread out much further across a graph. I can't set y = x.initialized_value() + 1.0 because I want y = x + 1.0 on later graph runs (after I mutate x). Using initialized_value would require separate code paths (/ subgraphs) for initializing and running the graph.

I think I understand the problem with the default behavior. In my case, z's initializer has a data dependency that goes back to x. However, x is still uninitialized at this time, which causes an error. I attempted to work around this by feeding in the initialized_value for variable.value(), but this isn't a workaround because you can't always evaluate initialized_value if the initializer depends on other non-initialized variables. If there were a mode of Session.run that defaulted to initial variable values, then this strategy could work. Sadly, this cannot be done with tf.select and an "init_mode" placeholder because tf.select(true, v.initialized_value(), v) fails for uninitialized v.

@alextp (Contributor) commented Jan 18, 2017 via email

@yaroslavvb (Contributor, Author)

@eamartin actually, maybe an easier solution than a toposort with multiple session.run calls is to have a wrapper around the variable that recovers the previous behavior (a separate variable/read op which you can use to trigger initialization on read). Then you could do var = wrap_variable(var), and that would give you a tensor that runs the variable's initialization the first time it's read, using tf.cond on tf.is_variable_initialized(var).
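
A sketch of what such a wrap_variable could look like (the name and details are hypothetical, and this assumes TF 1.x graph-mode variables):

import tensorflow as tf

def wrap_variable(var):
    """Hypothetical wrapper: a tensor that reads var, but runs var's
    initializer first if var has not been initialized yet."""
    def init_then_read():
        with tf.control_dependencies([tf.assign(var, var.initial_value)]):
            # read_value() creates a fresh read op here, ordered after the assign
            return var.read_value()
    def just_read():
        return var.read_value()
    return tf.cond(tf.is_variable_initialized(var), just_read, init_then_read)

One caveat: with TF 1.x cond semantics, the ops feeding var.initial_value may still be evaluated even when the already-initialized branch is taken, and the cond runs on every read, so this is a sketch of the idea rather than an efficient implementation.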

@ZhimingZhou

@yaroslavvb Thanks for the new version. But I found it still re-does the initialization. There still seem to be some modifications to the graph.

BTW, the assert on 'unique tf.identity' seems not to hold. In my case, some nodes have more than one identity, and the Adam optimizer's moving mean and variance parameters have zero consumers. [Strange, but it should be the case, according to the debug info]

@alextp @eamartin @yaroslavvb
I still have not found an efficient way to do the initialization when it depends on input data and involves data flow. For example:

@eamartin 
x = tf.Variable(5.0)
y = x + 1.0
z = tf.Variable(tf.zeros_like(y))

initial_value or initialized_value() doesn't seem suitable when data manipulation outside the variable initializer is involved. But this is essential when people want to do data-dependent initialization.

The following works, but it could be too slow. It reruns the graph for each variable.

@eamartin 
for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    sess.run(v.initializer)

As you mentioned, @eamartin @yaroslavvb, I guess a simple workaround is to run the variable initializations in order within one sess.run(). I tried but failed to finish it. For example, I tried adding all the variables in order, but it doesn't seem to work. Could you help? Thanks.

@eamartin commented Jan 21, 2017 via email

@aselle added the type:feature (Feature requests) label and removed the enhancement label Feb 9, 2017
@alextp (Contributor) commented Mar 23, 2017

Please take a look at the new behavior of Variable.initialized_value (introduced in b25d1c7). Now it doesn't force initialization, so it can be safely used when initializing variables from other variables.

The only overhead of using this instead of the raw Variable is that it will trigger recomputing the initializer at the beginning of the step even if the variable has already been initialized (this is mildly annoying to fix, but doable).

So it should be kosher to use both in a single and in multiple session.run calls to initialize stuff.

@yaroslavvb (Contributor, Author)

Progress!

BTW, we use Variables that depend on each other, and some are initialized from placeholders. Recomputing the initializer means you have to feed placeholder values that aren't getting used.

As a work-around I've been recommending the pattern below, which uses graph_editor + Switch + Merge to rewrite the graph and make initializer execution lazy.

https://gist.github.com/yaroslavvb/d67410e240369736fc4ba0267250ef27

@alextp (Contributor) commented Mar 23, 2017 via email

@girving (Contributor) commented Jun 16, 2017

@yaroslavvb Is this resolved as of b25d1c7?

@girving (Contributor) commented Jun 16, 2017

Closing for now since it looks like yes, but happy to reopen if not.

@girving closed this as completed Jun 16, 2017
@TheButlah commented Jul 11, 2017

I think this should be reopened. Initializing variables from other variables that depend on placeholders is still really difficult. It would be ideal if it were achievable in a single run call, without depending on the graph editor method that @yaroslavvb provided. Although the behavior of initialized_value changed so that it uses tf.cond to decide whether or not to run the initializer, this is still not ideal for cases such as the one below:

    sess = tf.InteractiveSession()

    x = tf.placeholder(tf.int32, (), name='X')
    a = tf.Variable(x, name='A', trainable=False)  # First cached variable (depends on placeholder)
    b = a + 1  # Pretend this is a huge amount of code
    c = tf.Variable(b+1, name='C', trainable=False)  # Second cached variable
    d = b + 3  # Pretend this is a bunch more code
    final_output = d + c

    is_init = tf.is_variable_initialized
    with tf.control_dependencies([a.initializer]):
        is_a = tf.Print(is_init(a), [is_init(a)], "A: ")
        with tf.control_dependencies([is_a]):
            with tf.control_dependencies([c.initializer]):
                is_c = tf.Print(is_init(c), [is_init(c)], "C: ")
                init_op = is_c.op

    init_op.run(feed_dict={x: 1})

    print(a.eval())
    print(c.eval())

That code, seemingly randomly, sometimes fails and sometimes works. Using initial_value doesn't help, because then the placeholder would have to be fed on every subsequent run, and using initialized_value doesn't work for the same reason (also, while it's okay to do the tf.cond call once for initialization, repeatedly calling it on every training step of a more complex model would be very slow). Or is there a new, better way that I am not aware of?
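
For what it's worth, the dependency-ordered, per-variable pattern from earlier in the thread still seems to apply to this example, at the cost of one run call per variable rather than the single call asked for above (a sketch, untested; d and final_output are omitted since they don't affect initialization):

sess = tf.InteractiveSession()

x = tf.placeholder(tf.int32, (), name='X')
a = tf.Variable(x, name='A', trainable=False)
b = a + 1
c = tf.Variable(b + 1, name='C', trainable=False)

# Initialize in dependency order; only a's initializer needs the placeholder.
sess.run(a.initializer, feed_dict={x: 1})
sess.run(c.initializer)  # reads the now-initialized value of a

print(a.eval())  # 1
print(c.eval())  # 3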
