Skip to content

Commit 1a025e6

Browse files
committed
add moving_averages, fix ExponentialMovingAverage
1 parent a4bede5 commit 1a025e6

File tree

15 files changed

+261
-34
lines changed

15 files changed

+261
-34
lines changed

TensorFlow.NET.sln

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t
99
EndProject
1010
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
1111
EndProject
12-
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.Core", "src\KerasNET.Core\Keras.Core.csproj", "{902E188F-A953-43B4-9991-72BAB1697BC3}"
13-
EndProject
1412
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}"
1513
EndProject
1614
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowBenchmark", "src\TensorFlowNet.Benchmarks\TensorFlowBenchmark.csproj", "{68861442-971A-4196-876E-C9330F0B3C54}"
@@ -41,10 +39,6 @@ Global
4139
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
4240
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
4341
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
44-
{902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
45-
{902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.Build.0 = Debug|Any CPU
46-
{902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.ActiveCfg = Release|Any CPU
47-
{902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.Build.0 = Release|Any CPU
4842
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
4943
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU
5044
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,20 @@ public object get_collection(string name, string scope = null)
420420

421421
public List<T> get_collection<T>(string name, string scope = null)
422422
{
423-
return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>();
423+
List<T> t = default;
424+
var collection = _collections.ContainsKey(name) ? _collections[name] : new List<T>();
425+
switch (collection)
426+
{
427+
case List<VariableV1> list:
428+
t = list.Select(x => (T)(object)x).ToList();
429+
break;
430+
case List<RefVariable> list:
431+
t = list.Select(x => (T)(object)x).ToList();
432+
break;
433+
default:
434+
throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
435+
}
436+
return t;
424437
}
425438

426439
public object get_collection_ref(string name)

src/TensorFlowNET.Core/Operations/Operation.Output.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using System;
1818
using System.Linq;
1919
using System.Runtime.InteropServices;
20+
using static Tensorflow.Binding;
2021

2122
namespace Tensorflow
2223
{
@@ -48,6 +49,20 @@ public int OutputListLength(string name)
4849

4950
public TF_Output this[int index] => _tf_output(index);
5051

52+
/// <summary>
53+
/// List this operation's output types.
54+
/// </summary>
55+
public TF_DataType[] _output_types
56+
{
57+
get
58+
{
59+
var output_types = range(NumOutputs)
60+
.Select(i => OutputType(i))
61+
.ToArray();
62+
return output_types;
63+
}
64+
}
65+
5166
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
5267
{
5368
var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Input>());

src/TensorFlowNET.Core/Operations/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ public partial class c_api
198198
/// <param name="max_consumers">int</param>
199199
/// <returns></returns>
200200
[DllImport(TensorFlowLibName)]
201-
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers);
201+
public static extern int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers);
202202

203203
[DllImport(TensorFlowLibName)]
204204
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);

src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class ExponentialMovingAverage
1313
bool _zero_debias;
1414
string _name;
1515
public string name => _name;
16-
List<VariableV1> _averages;
16+
Dictionary<RefVariable, RefVariable> _averages;
1717

1818
public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false,
1919
string name = "ExponentialMovingAverage")
@@ -22,7 +22,7 @@ public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_
2222
_num_updates = num_updates;
2323
_zero_debias = zero_debias;
2424
_name = name;
25-
_averages = new List<VariableV1>();
25+
_averages = new Dictionary<RefVariable, RefVariable>();
2626
}
2727

2828
/// <summary>
@@ -37,16 +37,38 @@ public Operation apply(RefVariable[] var_list = null)
3737

3838
foreach(var var in var_list)
3939
{
40-
if (!_averages.Contains(var))
40+
if (!_averages.ContainsKey(var))
4141
{
4242
ops.init_scope();
43-
var slot = new SlotCreator();
44-
var.initialized_value();
45-
// var avg = slot.create_zeros_slot
43+
var slot_creator = new SlotCreator();
44+
var value = var.initialized_value();
45+
var avg = slot_creator.create_slot(var,
46+
value,
47+
name,
48+
colocate_with_primary: true);
49+
ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var);
50+
_averages[var] = avg;
4651
}
4752
}
4853

49-
throw new NotImplementedException("");
54+
return tf_with(ops.name_scope(name), scope =>
55+
{
56+
var decay = ops.convert_to_tensor(_decay, name: "decay");
57+
if (_num_updates.HasValue)
58+
{
59+
throw new NotImplementedException("ExponentialMovingAverage.apply");
60+
}
61+
62+
var updates = new List<Tensor>();
63+
foreach (var var in var_list)
64+
{
65+
var zero_debias = false;// _averages[var] in zero_debias_true
66+
var ama = moving_averages.assign_moving_average(_averages[var], var, decay, zero_debias: zero_debias);
67+
updates.Add(ama);
68+
}
69+
70+
return control_flow_ops.group(updates.ToArray(), name: scope);
71+
});
5072
}
5173
}
5274
}

src/TensorFlowNET.Core/Train/SlotCreator.cs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,24 @@ namespace Tensorflow.Train
2222
{
2323
public class SlotCreator
2424
{
25+
/// <summary>
26+
/// Create a slot initialized to the given value.
27+
/// </summary>
28+
/// <param name="primary"></param>
29+
/// <param name="val"></param>
30+
/// <param name="name"></param>
31+
/// <param name="colocate_with_primary"></param>
32+
/// <returns></returns>
33+
public RefVariable create_slot(RefVariable primary, Tensor val, string name, bool colocate_with_primary = true)
34+
{
35+
var validate_shape = val.TensorShape.is_fully_defined();
36+
var prefix = primary.op.name;
37+
return tf_with(tf.variable_scope(name: null, prefix + "/" + name), delegate
38+
{
39+
return _create_slot_var(primary, val, "", validate_shape, null, TF_DataType.DtInvalid);
40+
});
41+
}
42+
2543
/// <summary>
2644
/// Create a slot initialized to 0 with same shape as the primary object.
2745
/// </summary>
@@ -73,7 +91,7 @@ public RefVariable create_slot_with_initializer(RefVariable primary, IInitialize
7391
/// <param name="shape"></param>
7492
/// <param name="dtype"></param>
7593
/// <returns></returns>
76-
private RefVariable _create_slot_var(VariableV1 primary, IInitializer val, string scope, bool validate_shape,
94+
private RefVariable _create_slot_var(VariableV1 primary, object val, string scope, bool validate_shape,
7795
TensorShape shape, TF_DataType dtype)
7896
{
7997
bool use_resource = primary is ResourceVariable;
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow.Train
7+
{
8+
public class moving_averages
9+
{
10+
/// <summary>
11+
/// Compute the moving average of a variable.
12+
/// </summary>
13+
/// <param name="variable"></param>
14+
/// <param name="value"></param>
15+
/// <param name="decay"></param>
16+
/// <param name="zero_debias"></param>
17+
/// <param name="name"></param>
18+
/// <returns></returns>
19+
public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay,
20+
bool zero_debias = true, string name = null)
21+
{
22+
tf_with(ops.name_scope(name, "", new { variable, value, decay }), scope =>
23+
{
24+
decay = ops.convert_to_tensor(1.0f - decay, name: "decay");
25+
if (decay.dtype != variable.dtype.as_base_dtype())
26+
decay = math_ops.cast(decay, variable.dtype.as_base_dtype());
27+
});
28+
29+
throw new NotImplementedException("assign_moving_average");
30+
}
31+
}
32+
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using Google.Protobuf;
1818
using System;
1919
using System.Collections.Generic;
20+
using System.Linq;
2021
using static Tensorflow.Binding;
2122

2223
namespace Tensorflow
@@ -176,7 +177,7 @@ private void _init_from_args(object initial_value,
176177
// If 'initial_value' makes use of other variables, make sure we don't
177178
// have an issue if these other variables aren't initialized first by
178179
// using their initialized_value() method.
179-
var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);
180+
var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value);
180181

181182
_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;
182183

@@ -215,9 +216,9 @@ public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvali
215216
/// Attempt to guard against dependencies on uninitialized variables.
216217
/// </summary>
217218
/// <param name="initial_value"></param>
218-
private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value)
219+
private Tensor _try_guard_against_uninitialized_dependencies(string name, Tensor initial_value)
219220
{
220-
return _safe_initial_value_from_tensor(initial_value, new Dictionary<string, Operation>());
221+
return _safe_initial_value_from_tensor(name, initial_value, op_cache: new Dictionary<string, Operation>());
221222
}
222223

223224
/// <summary>
@@ -226,19 +227,19 @@ private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_valu
226227
/// <param name="tensor">A `Tensor`. The tensor to replace.</param>
227228
/// <param name="op_cache">A dict mapping operation names to `Operation`s.</param>
228229
/// <returns>A `Tensor` compatible with `tensor`.</returns>
229-
private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary<string, Operation> op_cache)
230+
private Tensor _safe_initial_value_from_tensor(string name, Tensor tensor, Dictionary<string, Operation> op_cache)
230231
{
231232
var op = tensor.op;
232233
var new_op = op_cache.ContainsKey(op.name) ? op_cache[op.name] : null;
233234
if(new_op == null)
234235
{
235-
new_op = _safe_initial_value_from_op(op, op_cache);
236+
new_op = _safe_initial_value_from_op(name, op, op_cache);
236237
op_cache[op.name] = new_op;
237238
}
238239
return new_op.outputs[tensor.value_index];
239240
}
240241

241-
private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, Operation> op_cache)
242+
private Operation _safe_initial_value_from_op(string name, Operation op, Dictionary<string, Operation> op_cache)
242243
{
243244
var op_type = op.node_def.Op;
244245
switch (op_type)
@@ -250,13 +251,50 @@ private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, O
250251
case "Variable":
251252
case "VariableV2":
252253
case "VarHandleOp":
253-
break;
254+
var initialized_value = _find_initialized_value_for_variable(op);
255+
return initialized_value == null ? op : initialized_value.op;
254256
}
255257

256258
// Recursively build initializer expressions for inputs.
259+
var modified = false;
260+
var new_op_inputs = new List<Tensor>();
261+
foreach (var op_input in op.inputs)
262+
{
263+
var new_op_input = _safe_initial_value_from_tensor(name, op_input as Tensor, op_cache);
264+
new_op_inputs.Add(new_op_input);
265+
modified = modified || new_op_input != op_input;
266+
}
267+
268+
// If at least one input was modified, replace the op.
269+
if (modified)
270+
{
271+
var new_op_type = op_type;
272+
if (new_op_type == "RefSwitch")
273+
new_op_type = "Switch";
274+
var new_op_name = op.node_def.Name + "_" + name;
275+
new_op_name = new_op_name.Replace(":", "_");
276+
var attrs = new Dictionary<string, AttrValue>();
277+
attrs[op.node_def.Name] = op.node_def.Attr.ElementAt(0).Value;
278+
/*return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types,
279+
name: new_op_name, attrs: attrs);*/
280+
}
257281
return op;
258282
}
259283

284+
private Operation _find_initialized_value_for_variable(Operation variable_op)
285+
{
286+
var var_names = new[] { variable_op.node_def.Name, variable_op.node_def.Name + ":0" };
287+
foreach(var collection_name in new[]{tf.GraphKeys.GLOBAL_VARIABLES,
288+
tf.GraphKeys.LOCAL_VARIABLES })
289+
{
290+
foreach (var var in variable_op.graph.get_collection<RefVariable>(collection_name))
291+
if (var_names.Contains(var.name))
292+
return var.initialized_value();
293+
}
294+
295+
return null;
296+
}
297+
260298
/// <summary>
261299
/// Assigns a new value to the variable.
262300
/// </summary>
@@ -318,6 +356,15 @@ private ITensorOrOperation read_value()
318356
return array_ops.identity(_variable, name: "read");
319357
}
320358

359+
/// <summary>
360+
/// Returns the Tensor used as the initial value for the variable.
361+
/// </summary>
362+
/// <returns></returns>
363+
private ITensorOrOperation initial_value()
364+
{
365+
return _initial_value;
366+
}
367+
321368
public Tensor is_variable_initialized(RefVariable variable)
322369
{
323370
return state_ops.is_variable_initialized(variable);
@@ -326,10 +373,9 @@ public Tensor is_variable_initialized(RefVariable variable)
326373
public Tensor initialized_value()
327374
{
328375
ops.init_scope();
329-
throw new NotImplementedException("");
330-
/*return control_flow_ops.cond(is_variable_initialized(this),
376+
return control_flow_ops.cond(is_variable_initialized(this),
331377
read_value,
332-
() => initial_value);*/
378+
initial_value);
333379
}
334380
}
335381
}

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor update
149149

150150
public static Tensor is_variable_initialized(RefVariable @ref, string name = null)
151151
{
152-
throw new NotImplementedException("");
152+
var _op = _op_def_lib._apply_op_helper("IsVariableInitialized", name: name, args: new { @ref });
153+
return _op.output;
153154
}
154155
}
155156
}

src/TensorFlowNET.Core/ops.GraphKeys.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ public class GraphKeys
5252
/// </summary>
5353
public const string LOSSES_ = "losses";
5454

55+
public const string MOVING_AVERAGE_VARIABLES = "moving_average_variables";
56+
5557
/// <summary>
5658
/// Key to collect Variable objects that are global (shared across machines).
5759
/// Default collection for all variables, except local ones.
@@ -100,6 +102,12 @@ public class GraphKeys
100102
/// </summary>
101103
public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_;
102104

105+
/// <summary>
106+
/// Key to collect local variables that are local to the machine and are not
107+
/// saved/restored.
108+
/// </summary>
109+
public string LOCAL_VARIABLES = "local_variables";
110+
103111
/// <summary>
104112
/// Key to collect losses
105113
/// </summary>
@@ -109,7 +117,7 @@ public class GraphKeys
109117
/// Key to collect Variable objects that are global (shared across machines).
110118
/// Default collection for all variables, except local ones.
111119
/// </summary>
112-
public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_;
120+
public string GLOBAL_VARIABLES = GLOBAL_VARIABLES_;
113121

114122
public string TRAIN_OP => TRAIN_OP_;
115123

0 commit comments

Comments
 (0)