Skip to content

Commit 7a706c9

Browse files
committed
fix unique_layer_name
1 parent 980201d commit 7a706c9

File tree

11 files changed

+204
-140
lines changed

11 files changed

+204
-140
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ public partial class Graph : IPython, IDisposable
6969
private List<Tensor> _unfeedable_tensors = new List<Tensor>();
7070

7171
public string _name_stack = "";
72-
public string _graph_key;
72+
private string _graph_key;
73+
public string graph_key => _graph_key;
7374
public string _last_loss_reduction;
7475

7576
public Status Status { get; }

src/TensorFlowNET.Core/Keras/Engine/Sequential.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ public void __enter__()
1919

2020
}
2121

22+
/// <summary>
23+
/// Adds a layer instance on top of the layer stack.
24+
/// </summary>
25+
/// <param name="layer"></param>
2226
public void add(Layer layer)
2327
{
2428
built = false;
@@ -32,7 +36,7 @@ public void add(Layer layer)
3236
var x = keras.layers.Input(
3337
batch_shape: batch_shape,
3438
dtype: dtype,
35-
name: layer._name + "_input");
39+
name: layer.name + "_input");
3640

3741
// This will build the current layer
3842
// and create the node connecting the current layer

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Text;
55
using Tensorflow.Keras.Engine;
66
using Tensorflow.Keras.Utils;
7+
using static Tensorflow.Python;
78

89
namespace Tensorflow.Keras.Layers
910
{
@@ -33,7 +34,8 @@ public class Layer : CheckpointableBase
3334
protected InputSpec input_spec;
3435
protected bool supports_masking;
3536
protected List<RefVariable> _trainable_weights;
36-
public string _name;
37+
private string _name;
38+
public string name => _name;
3739
protected string _base_name;
3840
protected bool _compute_previous_mask;
3941
protected List<Operation> _updates;
@@ -85,17 +87,24 @@ public Tensor __call__(Tensor[] inputs,
8587
// Handle Keras mask propagation from previous layer to current layer.
8688
Python.with(ops.name_scope(_name_scope()), delegate
8789
{
88-
if (!built)
90+
/*if (!built)
8991
{
9092
_maybe_build(inputs);
9193
built = true;
92-
}
94+
}*/
9395

9496
if (build_graph)
9597
{
9698
// Symbolic execution on symbolic tensors. We will attempt to build
9799
// the corresponding TF subgraph inside `backend.get_graph()`
98-
var graph = backend.get_graph();
100+
var graph = backend.get_graph().as_default();
101+
with(ops.name_scope(_name_scope()), delegate
102+
{
103+
// Build layer if applicable (if the `build` method has been
104+
// overridden).
105+
_maybe_build(inputs[0]);
106+
});
107+
99108
outputs = call(inputs[0], training: training);
100109
_handle_activity_regularization(inputs[0], outputs);
101110
_set_mask_metadata(inputs[0], outputs, null);
@@ -130,13 +139,17 @@ protected virtual Tensor call(Tensor inputs, Tensor training = null)
130139

131140
protected virtual string _name_scope()
132141
{
133-
return null;
142+
return name;
134143
}
135144

136-
protected void _maybe_build(Tensor[] inputs)
145+
protected void _maybe_build(Tensor input)
137146
{
138-
var input_list = inputs;
139-
build(input_list[0].GetShape());
147+
// Check input assumptions set before layer building, e.g. input rank.
148+
if (built)
149+
return;
150+
151+
build(input.GetShape());
152+
built = true;
140153
}
141154

142155
protected virtual void build(TensorShape input_shape)
@@ -160,7 +173,7 @@ protected virtual RefVariable add_weight(string name,
160173
var variable = _add_variable_with_custom_getter(name,
161174
shape,
162175
dtype: dtype,
163-
getter: getter == null ? base_layer_utils.make_variable : getter,
176+
//getter: getter == null ? base_layer_utils.make_variable : getter,
164177
overwrite: true,
165178
initializer: initializer,
166179
trainable: trainable.Value);
@@ -176,12 +189,12 @@ protected virtual void add_update(Tensor[] updates, bool inputs = false)
176189
_updates.AddRange(updates_op);
177190
}
178191

179-
protected virtual void _init_set_name(string name)
192+
protected virtual void _init_set_name(string name, bool zero_based = true)
180193
{
181-
string base_name = name;
182194
if (name == null)
183-
(_name, base_name) = _make_unique_name();
184-
_base_name = base_name;
195+
_name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(this.GetType().Name), zero_based: zero_based);
196+
else
197+
_name = name;
185198
}
186199

187200
protected virtual (string, string) _make_unique_name()

src/TensorFlowNET.Core/Keras/Sequence.cs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,6 @@ public NDArray pad_sequences(NDArray sequences,
3030
object value = null)
3131
{
3232
int[] length = new int[sequences.size];
33-
switch (sequences.dtype.Name)
34-
{
35-
case "Object":
36-
for (int i = 0; i < sequences.size; i++)
37-
{
38-
switch (sequences.Data<object>(i))
39-
{
40-
case string data:
41-
length[i] = Regex.Matches(data, ",").Count;
42-
break;
43-
}
44-
}
45-
break;
46-
case "Int32":
47-
for (int i = 0; i < sequences.size; i++)
48-
length[i] = Regex.Matches(sequences.Data<object>(i).ToString(), ",").Count;
49-
break;
50-
default:
51-
throw new NotImplementedException($"pad_sequences: {sequences.dtype.Name}");
52-
}
5333

5434
if (maxlen == null)
5535
maxlen = length.Max();
Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,96 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
5+
using static Tensorflow.Python;
46

57
namespace Tensorflow.Keras.Utils
68
{
79
public class base_layer_utils
810
{
11+
/// <summary>
12+
/// Adds a new variable to the layer.
13+
/// </summary>
14+
/// <param name="name"></param>
15+
/// <param name="shape"></param>
16+
/// <param name="dtype"></param>
17+
/// <param name="initializer"></param>
18+
/// <param name="trainable"></param>
19+
/// <returns></returns>
920
public static RefVariable make_variable(string name,
1021
int[] shape,
1122
TF_DataType dtype = TF_DataType.TF_FLOAT,
1223
IInitializer initializer = null,
13-
bool trainable = false)
24+
bool trainable = true,
25+
bool use_resource = true)
1426
{
15-
throw new NotImplementedException("");
27+
var initializing_from_value = false;
28+
29+
ops.init_scope();
30+
31+
Func<Tensor> init_val = ()=> initializer.call(new TensorShape(shape), dtype: dtype);
32+
33+
var variable_dtype = dtype.as_base_dtype();
34+
var v = tf.Variable(init_val);
35+
36+
return v;
1637
}
1738

1839
/// <summary>
1940
/// Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
2041
/// </summary>
2142
/// <param name="name"></param>
2243
/// <returns></returns>
23-
public static string unique_layer_name(string name)
44+
public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null,
45+
string[] avoid_names = null, string @namespace = "", bool zero_based = false)
2446
{
25-
int number = get_default_graph_uid_map();
26-
return $"{name}_{number}";
47+
if(name_uid_map == null)
48+
name_uid_map = get_default_graph_uid_map();
49+
if (avoid_names == null)
50+
avoid_names = new string[0];
51+
52+
string proposed_name = null;
53+
while(proposed_name == null || avoid_names.Contains(proposed_name))
54+
{
55+
var name_key = (@namespace, name);
56+
if (!name_uid_map.ContainsKey(name_key))
57+
name_uid_map[name_key] = 0;
58+
59+
if (zero_based)
60+
{
61+
int number = name_uid_map[name_key];
62+
if (number > 0)
63+
proposed_name = $"{name}_{number}";
64+
else
65+
proposed_name = name;
66+
67+
name_uid_map[name_key] += 1;
68+
}
69+
else
70+
{
71+
name_uid_map[name_key] += 1;
72+
proposed_name = $"{name}_{name_uid_map[name_key]}";
73+
}
74+
}
75+
76+
return proposed_name;
2777
}
2878

29-
public static int get_default_graph_uid_map()
79+
public static Dictionary<(string, string), int> get_default_graph_uid_map()
3080
{
3181
var graph = ops.get_default_graph();
32-
return graph._next_id();
82+
Dictionary<(string, string), int> name_uid_map = null;
83+
if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph.graph_key))
84+
{
85+
name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key];
86+
}
87+
else
88+
{
89+
name_uid_map = new Dictionary<(string, string), int>();
90+
backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key] = name_uid_map;
91+
}
92+
93+
return name_uid_map;
3394
}
3495
}
3596
}

src/TensorFlowNET.Core/Keras/backend.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ namespace Tensorflow.Keras
66
{
77
public class backend
88
{
9+
/// <summary>
10+
/// A global dictionary mapping graph objects to an index of counters used
11+
/// for various layer names in each graph.
12+
/// Allows to give unique autogenerated names to layers, in a graph-specific way.
13+
/// </summary>
14+
public static Dictionary<string, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<string, Dictionary<(string, string), int>>();
15+
916
public static void track_variable(RefVariable v)
1017
{
1118

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,12 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Train;
45

56
namespace Tensorflow
67
{
7-
public abstract class CheckpointableBase
8+
public abstract class CheckpointableBase : Trackable
89
{
9-
/// <summary>
10-
/// Restore-on-create for a variable be saved with this `Checkpointable`.
11-
/// </summary>
12-
/// <returns></returns>
13-
protected virtual RefVariable _add_variable_with_custom_getter(string name,
14-
int[] shape,
15-
TF_DataType dtype = TF_DataType.TF_FLOAT,
16-
IInitializer initializer = null,
17-
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null,
18-
bool overwrite = false,
19-
bool trainable = false)
20-
{
21-
var new_variable = getter(name, shape, dtype, initializer, trainable);
22-
if (!overwrite || new_variable is RefVariable)
23-
return _track_checkpointable(new_variable, name: name,
24-
overwrite: overwrite);
25-
else
26-
return new_variable;
27-
}
2810

29-
protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false)
30-
{
31-
return checkpointable;
32-
}
3311
}
3412
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Train
6+
{
7+
public abstract class Trackable
8+
{
9+
/// <summary>
10+
/// Restore-on-create for a variable be saved with this `Checkpointable`.
11+
/// </summary>
12+
/// <returns></returns>
13+
protected virtual RefVariable _add_variable_with_custom_getter(string name,
14+
int[] shape,
15+
TF_DataType dtype = TF_DataType.TF_FLOAT,
16+
IInitializer initializer = null,
17+
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null,
18+
bool overwrite = false,
19+
bool trainable = false)
20+
{
21+
var new_variable = getter(name, shape, dtype, initializer, trainable);
22+
if (!overwrite || new_variable is RefVariable)
23+
return _track_checkpointable(new_variable, name: name,
24+
overwrite: overwrite);
25+
else
26+
return new_variable;
27+
}
28+
29+
protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false)
30+
{
31+
return checkpointable;
32+
}
33+
}
34+
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ private void _init_from_args(object initial_value,
111111

112112
// Store the graph key so optimizers know how to only retrieve variables from
113113
// this graph.
114-
_graph_key = ops.get_default_graph()._graph_key;
114+
_graph_key = ops.get_default_graph().graph_key;
115115

116116
_trainable = trainable;
117117
if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))

0 commit comments

Comments
 (0)