forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathKerasLayer.cs
More file actions
158 lines (140 loc) · 5.43 KB
/
KerasLayer.cs
File metadata and controls
158 lines (140 loc) · 5.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using static Tensorflow.Binding;
namespace Tensorflow.Hub
{
/// <summary>
/// Keras layer that wraps a module loaded through <c>module_v2.load</c> (e.g. a TF Hub SavedModel)
/// and exposes its variables as weights of this layer and its callable as this layer's forward pass.
/// </summary>
public class KerasLayer : Layer
{
    private readonly string _handle;
    private readonly LoadOptions? _load_options;
    private readonly Trackable _func;
    private readonly Func<Tensors, Tensors> _callable;

    // Set once the "trainable but zero trainable weights" error has been logged, so that
    // _check_trainability (run on every Call) reports it at most once per layer instance.
    private bool _already_logged_trainable_with_zero_weights;

    /// <summary>
    /// Loads the module at <paramref name="handle"/>, tracks it under the name "_func",
    /// resolves its callable and registers its variables as weights of this layer.
    /// </summary>
    /// <param name="handle">Handle passed unchanged to <c>module_v2.load</c>.</param>
    /// <param name="trainable">Forwarded to the base layer's <c>LayerArgs.Trainable</c>.</param>
    /// <param name="load_options">Optional load options passed to <c>module_v2.load</c>.</param>
    /// <exception cref="ValueError">Thrown when no callable can be obtained from the loaded object.</exception>
    /// <exception cref="NotImplementedException">
    /// Thrown when the loaded object carries a non-empty <c>regularization_losses</c> list.
    /// </exception>
    public KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) :
        base(new Keras.ArgsDefinition.LayerArgs() { Trainable = trainable })
    {
        _handle = handle;
        _load_options = load_options;
        _func = load_module(_handle, _load_options);
        _track_trackable(_func, "_func");
        // TODO(Rinne): deal with _is_hub_module_v1.
        _callable = _get_callable();
        _setup_layer(trainable);
    }

    /// <summary>
    /// Registers the loaded object's variables as weights of this layer: first every trainable
    /// variable as a trainable weight, then every remaining variable as a non-trainable weight.
    /// </summary>
    /// <param name="trainable">
    /// Currently unused; kept for parity with the Python implementation's signature.
    /// </param>
    /// <exception cref="NotImplementedException">
    /// Thrown when the loaded object exposes a non-empty <c>regularization_losses</c> ListWrapper.
    /// </exception>
    private void _setup_layer(bool trainable = false)
    {
        // Record the ids of the trainable variables so the second pass does not
        // re-register them as non-trainable weights.
        var trainable_ids = new HashSet<string>();
        foreach (var v in _get_func_variables(trainable_only: true))
        {
            _add_existing_weight(v, true);
            trainable_ids.Add(v.UniqueId);
        }

        foreach (var v in _get_func_variables(trainable_only: false))
        {
            if (!trainable_ids.Contains(v.UniqueId))
            {
                _add_existing_weight(v, false);
            }
        }

        if (_func.CustomizedFields.TryGetValue("regularization_losses", out var losses)
            && losses is ListWrapper loss_list
            && loss_list.Count > 0)
        {
            throw new NotImplementedException("The regularization_losses loading has not been supported yet, " +
                "please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues to let us know and add a feature.");
        }
    }

    /// <summary>
    /// Returns the loaded object's variables. When it is a <see cref="Layer"/>, its
    /// TrainableVariables / Variables are used. Otherwise the IVariableV1 entries of its
    /// "trainable_variables" / "variables" customized field are used. Returns an empty
    /// sequence when neither source is present.
    /// </summary>
    /// <param name="trainable_only">True for the trainable variables, false for all variables.</param>
    private IEnumerable<IVariableV1> _get_func_variables(bool trainable_only)
    {
        if (_func is Layer layer)
        {
            if (trainable_only)
            {
                return layer.TrainableVariables;
            }
            return layer.Variables;
        }

        var field_name = trainable_only ? "trainable_variables" : "variables";
        if (_func.CustomizedFields.TryGetValue(field_name, out var obj) && obj is IEnumerable<Trackable> trackables)
        {
            return trackables.OfType<IVariableV1>();
        }
        return Enumerable.Empty<IVariableV1>();
    }

    /// <summary>
    /// Forward pass: checks trainability, then invokes the loaded callable on <paramref name="inputs"/>.
    /// </summary>
    protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optionalArgs = null)
    {
        _check_trainability();
        // TODO(Rinne): deal with training_argument
        var result = _callable(inputs);
        return _apply_output_shape_if_set(inputs, result);
    }

    /// <summary>
    /// When the layer is trainable but has no trainable weights, logs an error.
    /// The error is logged at most once per instance, because this method runs on every call.
    /// </summary>
    private void _check_trainability()
    {
        if (!Trainable) return;
        // TODO(Rinne): deal with _is_hub_module_v1 and signature
        if (_already_logged_trainable_with_zero_weights) return;

        var weights = TrainableWeights;
        if (weights is null || weights.Count == 0)
        {
            // Zero trainable weights is suspicious but may be a valid boundary case, so only log.
            tf.Logger.Error("hub.KerasLayer is trainable but has zero trainable weights.");
            _already_logged_trainable_with_zero_weights = true;
        }
    }

    /// <summary>
    /// Placeholder for applying a user-specified output shape; currently returns
    /// <paramref name="result"/> unchanged.
    /// </summary>
    private Tensors _apply_output_shape_if_set(Tensors inputs, Tensors result)
    {
        // TODO(Rinne): implement it.
        return result;
    }

    /// <summary>
    /// Registers an existing variable as a weight of this layer. The getter returns
    /// <paramref name="weight"/> itself instead of creating a new variable.
    /// </summary>
    /// <param name="weight">The variable to register.</param>
    /// <param name="trainable">Overrides the variable's own Trainable flag when non-null.</param>
    private void _add_existing_weight(IVariableV1 weight, bool? trainable = null)
    {
        bool is_trainable = trainable ?? weight.Trainable;
        add_weight(weight.Name, weight.shape, weight.dtype, trainable: is_trainable, getter: x => weight);
    }

    /// <summary>
    /// Resolves the function invoked by <see cref="Call"/>: the loaded Layer's Apply if the loaded
    /// object is a Layer, otherwise a RestoredFunction stored in its "__call__" customized field.
    /// </summary>
    /// <exception cref="ValueError">Thrown when neither source is available.</exception>
    private Func<Tensors, Tensors> _get_callable()
    {
        if (_func is Layer layer)
        {
            return x => layer.Apply(x);
        }
        if (_func.CustomizedFields.TryGetValue("__call__", out var call) && call is RestoredFunction function)
        {
            return x => function.Apply(x);
        }
        throw new ValueError("Cannot get the callable from the model.");
    }

    /// <summary>
    /// Loads the module at <paramref name="handle"/> via <c>module_v2.load</c>.
    /// </summary>
    private static Trackable load_module(string handle, LoadOptions? load_options = null)
    {
        //var set_load_options = load_options ?? LoadContext.get_load_option();
        return module_v2.load(handle, load_options);
    }
}
}