forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.cs
More file actions
203 lines (177 loc) · 6.96 KB
/
Model.cs
File metadata and controls
203 lines (177 loc) · 6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
using System.Diagnostics;
using Tensorflow.Common.Types;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using Tensorflow.Util;
namespace Tensorflow.Keras.Engine
{
/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
public partial class Model : Layer, IModel
{
#pragma warning disable CS0169 // The field 'Model._cloning' is never used
// Not read or written in this part of the partial class.
bool _cloning;
#pragma warning restore CS0169 // The field 'Model._cloning' is never used
#pragma warning disable CS0108 // Member hides inherited member; missing new keyword
#pragma warning disable CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
// Per the suppressed warnings: hides an inherited member and is assigned but never read.
// NOTE(review): the assignment is not in this part of the partial class - presumably done on compile; confirm.
bool _is_compiled;
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword
// Not assigned in this part of the partial class.
ILossFunc loss;
// Backing field for the `Optimizer` property.
IOptimizer optimizer;
// Assigned by `_configure_steps_per_execution` as a TF_INT64 variable.
IVariableV1 _steps_per_execution;
// Exposed read-only through `IsGraphNetwork`.
protected bool _is_graph_network;
// Not assigned in this part of the partial class.
public Tensors inputs;
// Not assigned in this part of the partial class.
protected Tensors outputs;
// Read by `_set_save_spec`; when null or empty, pseudo input names are generated there instead.
protected List<string> input_names;
// Not assigned in this part of the partial class.
public string[] output_names;
// The three counters below are created by `_init_batch_counters` as TF_INT64 variables starting at 0.
IVariableV1 _train_counter;
IVariableV1 _test_counter;
IVariableV1 _predict_counter;
// Only referenced from commented-out code in `SetAttr` in this part of the class.
bool _base_model_initialized;
// Backing field for the `Stop_training` property.
bool stop_training;
// Set at most once by `_set_save_spec`; later calls return early when this is non-null.
TensorSpec _saved_model_inputs_spec;
/// <summary>
/// Read-only view of the <c>_is_graph_network</c> flag.
/// </summary>
public bool IsGraphNetwork
{
    get { return _is_graph_network; }
}
/// <summary>
/// Gets or sets the optimizer held in the <c>optimizer</c> field.
/// </summary>
public IOptimizer Optimizer
{
    get { return optimizer; }
    set { optimizer = value; }
}
/// <summary>
/// Gets or sets the flag held in the <c>stop_training</c> field.
/// </summary>
public bool Stop_training
{
    get { return stop_training; }
    set { stop_training = value; }
}
/// <summary>
/// Creates a model: forwards <paramref name="args"/> to the base constructor,
/// then creates the train/test/predict batch counters.
/// </summary>
/// <param name="args">Construction arguments, passed unchanged to the base class.</param>
public Model(ModelArgs args)
: base(args)
{
_init_batch_counters();
}
/// <summary>
/// Public entry point that simply forwards <paramref name="inputs"/> to
/// <c>_set_save_spec</c>.
/// </summary>
/// <param name="inputs">Input spec to record for saving.</param>
public void _set_inputs(TensorSpec inputs) => _set_save_spec(inputs);
/// <summary>
/// Builds a tensor spec for every flattened input, packs the specs back into the
/// structure of <paramref name="inputs"/> and stores the result in
/// <c>_saved_model_inputs_spec</c>. Does nothing when a spec is already stored.
/// </summary>
/// <param name="inputs">Input structure that is flattened and used as the packing template.</param>
internal void _set_save_spec(TensorSpec inputs)
{
    // The spec is recorded only once; later calls are no-ops.
    if (_saved_model_inputs_spec is not null)
    {
        return;
    }

    // Use the model's own input names, falling back to generated pseudo names
    // when none are available.
    var resolved_names = this.input_names;
    if (resolved_names is null || resolved_names.Count == 0)
    {
        resolved_names = compile_utils.create_pseudo_input_names(inputs);
    }

    var flat_inputs = nest.flatten(inputs);
    List<TensorSpec> specs = new();
    foreach (var (name, tensor) in zip(resolved_names, flat_inputs))
    {
        specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name));
    }

    // `as` yields null when the packed result is not a TensorSpec, so the
    // assertion must inspect the cast result. (`specs` was just created with
    // `new()` and can never be null, which made the previous check vacuous.)
    var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec;
    Debug.Assert(packed_specs is not null);
    _saved_model_inputs_spec = packed_specs;

    // A Sequential model without a build input shape derives it from the packed specs.
    if (this is Sequential && _buildInputShape is null)
    {
        _buildInputShape = nest.map_structure<TensorSpec, TensorShapeConfig>(x => x is null ? null : x.shape, packed_specs);
    }
}
/// <summary>
/// Creates the train/test/predict batch counters and then delegates to the
/// base class initialization.
/// </summary>
/// <param name="args">Layer arguments, passed unchanged to the base implementation.</param>
internal override void Initialize(LayerArgs args)
{
// Counters are created before the base class runs its own initialization.
_init_batch_counters();
base.Initialize(args);
}
/// <summary>
/// Stores <paramref name="steps_per_execution"/> in a new TF_INT64 variable
/// (aggregation: OnlyFirstReplica) assigned to <c>_steps_per_execution</c>.
/// </summary>
/// <param name="steps_per_execution">Initial value of the variable.</param>
void _configure_steps_per_execution(int steps_per_execution)
    => _steps_per_execution = tf.Variable(
        steps_per_execution,
        dtype: TF_DataType.TF_INT64,
        aggregation: VariableAggregation.OnlyFirstReplica);
/// <summary>
/// Stores the result of <c>_get_trainable_state()</c> in
/// <c>_compiled_trainable_state</c> and sets <c>keras.backend._GRAPH</c> to null.
/// </summary>
void _reset_compile_cache()
{
// Used to cache `trainable` attr of `Layer`s for `fit`.
_compiled_trainable_state = _get_trainable_state();
// NOTE(review): presumably forces the backend graph to be recreated on next use - confirm.
keras.backend._GRAPH = null;
}
/// <summary>
/// (Re)creates the train, test and predict counters, each as its own
/// TF_INT64 variable starting at 0 with OnlyFirstReplica aggregation.
/// </summary>
void _init_batch_counters()
{
    // All three counters share the same construction; each call yields a fresh variable.
    IVariableV1 new_counter() => tf.Variable(0L,
        dtype: TF_DataType.TF_INT64,
        aggregation: VariableAggregation.OnlyFirstReplica);

    _train_counter = new_counter();
    _test_counter = new_counter();
    _predict_counter = new_counter();
}
/// <summary>
/// A new list of the layers returned by <c>_flatten_layers</c> with
/// <c>recursive: false</c> and <c>include_self: false</c>.
/// </summary>
public override List<ILayer> Layers
{
    get
    {
        var flattened = _flatten_layers(recursive: false, include_self: false);
        return flattened.ToList();
    }
}
/// <summary>
/// The distinct trainable weights of every trainable tracked child followed by
/// the model's own <c>_trainable_weights</c>. Empty when the model itself is
/// not trainable.
/// </summary>
public override List<IVariableV1> TrainableWeights
{
    get
    {
        // skip the assertion of weights created.
        var collected = new List<IVariableV1>();
        if (Trainable)
        {
            // Children first (only the trainable ones), then the model's own weights.
            collected.AddRange(_self_tracked_trackables
                .Where(child => child.Trainable)
                .SelectMany(child => child.TrainableWeights));
            collected.AddRange(_trainable_weights);
            collected = collected.Distinct().ToList();
        }
        return collected;
    }
}
/// <summary>
/// The distinct weights that are not updated by training: the non-trainable
/// weights of every tracked child plus the model's own
/// <c>_non_trainable_weights</c>. When the model itself is not trainable, the
/// children's trainable weights and the model's own <c>_trainable_weights</c>
/// are reported here as well.
/// </summary>
public override List<IVariableV1> NonTrainableWeights
{
    get
    {
        // skip the assertion of weights created.
        var variables = new List<IVariableV1>();
        foreach (var trackable_obj in _self_tracked_trackables)
        {
            variables.AddRange(trackable_obj.NonTrainableWeights);
        }

        if (!Trainable)
        {
            // A frozen model reports every weight as non-trainable, so the
            // otherwise-trainable weights are appended too.
            foreach (var trackable_obj in _self_tracked_trackables)
            {
                variables.AddRange(trackable_obj.TrainableWeights);
            }
            variables.AddRange(_trainable_weights);
        }

        // The model's own non-trainable weights belong here whether or not the
        // model is trainable; previously they were only added in the frozen case.
        variables.AddRange(_non_trainable_weights);
        return variables.Distinct().ToList();
    }
}
/// <summary>
/// Returns the trackable children reported by the base implementation,
/// unchanged, for any <paramref name="save_type"/>.
/// </summary>
/// <param name="save_type">Kind of save being performed; forwarded to the base class.</param>
/// <param name="cache">Optional cache; forwarded to the base class.</param>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
    // TODO: for SaveType.SAVEDMODEL, deal with `train_function`, `test_function`,
    // `predict_function`, `train_tf_function`. Nothing special is done for it yet.
    return base._trackable_children(save_type, cache);
}
/// <summary>
/// Forwards the attribute assignment to the base implementation without any
/// model-specific handling.
/// </summary>
/// <param name="name">Attribute name.</param>
/// <param name="value">Attribute value.</param>
// TODO(Rinne): deal with "_self_setattr_tracking". The unfinished sketch checked
// whether every flattened value is a Layer, an IVariableV1 or has weights
// (base_layer_utils.has_weights) and then consulted `_base_model_initialized`.
public override void SetAttr(string name, object value)
    => base.SetAttr(name, value);
}
}