@@ -17,6 +17,7 @@ limitations under the License.
1717using Google . Protobuf ;
1818using System ;
1919using System . Collections . Generic ;
20+ using System . Linq ;
2021using static Tensorflow . Binding ;
2122
2223namespace Tensorflow
@@ -176,7 +177,7 @@ private void _init_from_args(object initial_value,
176177 // If 'initial_value' makes use of other variables, make sure we don't
177178 // have an issue if these other variables aren't initialized first by
178179 // using their initialized_value() method.
179- var _initial_value2 = _try_guard_against_uninitialized_dependencies ( _initial_value ) ;
180+ var _initial_value2 = _try_guard_against_uninitialized_dependencies ( name , _initial_value ) ;
180181
181182 _initializer_op = gen_state_ops . assign ( _variable , _initial_value2 , validate_shape ) . op ;
182183
@@ -215,9 +216,9 @@ public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvali
215216 /// Attempt to guard against dependencies on uninitialized variables.
216217 /// </summary>
217218 /// <param name="initial_value"></param>
218- private Tensor _try_guard_against_uninitialized_dependencies ( Tensor initial_value )
219+ private Tensor _try_guard_against_uninitialized_dependencies ( string name , Tensor initial_value )
219220 {
220- return _safe_initial_value_from_tensor ( initial_value , new Dictionary < string , Operation > ( ) ) ;
221+ return _safe_initial_value_from_tensor ( name , initial_value , op_cache : new Dictionary < string , Operation > ( ) ) ;
221222 }
222223
223224 /// <summary>
@@ -226,19 +227,19 @@ private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_valu
226227 /// <param name="tensor">A `Tensor`. The tensor to replace.</param>
227228 /// <param name="op_cache">A dict mapping operation names to `Operation`s.</param>
228229 /// <returns>A `Tensor` compatible with `tensor`.</returns>
229- private Tensor _safe_initial_value_from_tensor ( Tensor tensor , Dictionary < string , Operation > op_cache )
230+ private Tensor _safe_initial_value_from_tensor ( string name , Tensor tensor , Dictionary < string , Operation > op_cache )
230231 {
231232 var op = tensor . op ;
232233 var new_op = op_cache . ContainsKey ( op . name ) ? op_cache [ op . name ] : null ;
233234 if ( new_op == null )
234235 {
235- new_op = _safe_initial_value_from_op ( op , op_cache ) ;
236+ new_op = _safe_initial_value_from_op ( name , op , op_cache ) ;
236237 op_cache [ op . name ] = new_op ;
237238 }
238239 return new_op . outputs [ tensor . value_index ] ;
239240 }
240241
241- private Operation _safe_initial_value_from_op ( Operation op , Dictionary < string , Operation > op_cache )
242+ private Operation _safe_initial_value_from_op ( string name , Operation op , Dictionary < string , Operation > op_cache )
242243 {
243244 var op_type = op . node_def . Op ;
244245 switch ( op_type )
@@ -250,13 +251,50 @@ private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, O
250251 case "Variable" :
251252 case "VariableV2" :
252253 case "VarHandleOp" :
253- break ;
254+ var initialized_value = _find_initialized_value_for_variable ( op ) ;
255+ return initialized_value == null ? op : initialized_value . op ;
254256 }
255257
256258 // Recursively build initializer expressions for inputs.
259+ var modified = false ;
260+ var new_op_inputs = new List < Tensor > ( ) ;
261+ foreach ( var op_input in op . inputs )
262+ {
263+ var new_op_input = _safe_initial_value_from_tensor ( name , op_input as Tensor , op_cache ) ;
264+ new_op_inputs . Add ( new_op_input ) ;
265+ modified = modified || new_op_input != op_input ;
266+ }
267+
268+ // If at least one input was modified, replace the op.
269+ if ( modified )
270+ {
271+ var new_op_type = op_type ;
272+ if ( new_op_type == "RefSwitch" )
273+ new_op_type = "Switch" ;
274+ var new_op_name = op . node_def . Name + "_" + name ;
275+ new_op_name = new_op_name . Replace ( ":" , "_" ) ;
276+ var attrs = new Dictionary < string , AttrValue > ( ) ;
277+ attrs [ op . node_def . Name ] = op . node_def . Attr . ElementAt ( 0 ) . Value ;
278+ /*return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types,
279+ name: new_op_name, attrs: attrs);*/
280+ }
257281 return op ;
258282 }
259283
284+ private Operation _find_initialized_value_for_variable ( Operation variable_op )
285+ {
286+ var var_names = new [ ] { variable_op . node_def . Name , variable_op . node_def . Name + ":0" } ;
287+ foreach ( var collection_name in new [ ] { tf . GraphKeys . GLOBAL_VARIABLES ,
288+ tf . GraphKeys . LOCAL_VARIABLES } )
289+ {
290+ foreach ( var var in variable_op . graph . get_collection < RefVariable > ( collection_name ) )
291+ if ( var_names . Contains ( var . name ) )
292+ return var . initialized_value ( ) ;
293+ }
294+
295+ return null ;
296+ }
297+
260298 /// <summary>
261299 /// Assigns a new value to the variable.
262300 /// </summary>
@@ -318,6 +356,15 @@ private ITensorOrOperation read_value()
318356 return array_ops . identity ( _variable , name : "read" ) ;
319357 }
320358
359+ /// <summary>
360+ /// Returns the Tensor used as the initial value for the variable.
361+ /// </summary>
362+ /// <returns></returns>
363+ private ITensorOrOperation initial_value ( )
364+ {
365+ return _initial_value ;
366+ }
367+
321368 public Tensor is_variable_initialized ( RefVariable variable )
322369 {
323370 return state_ops . is_variable_initialized ( variable ) ;
@@ -326,10 +373,9 @@ public Tensor is_variable_initialized(RefVariable variable)
326373 public Tensor initialized_value ( )
327374 {
328375 ops . init_scope ( ) ;
329- throw new NotImplementedException ( "" ) ;
330- /*return control_flow_ops.cond(is_variable_initialized(this),
376+ return control_flow_ops . cond ( is_variable_initialized ( this ) ,
331377 read_value ,
332- () => initial_value);*/
378+ initial_value ) ;
333379 }
334380 }
335381}
0 commit comments