Skip to content

Commit fe76c9c

Browse files
committed
Add ResourceVariable native api.
1 parent e266030 commit fe76c9c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+584
-345
lines changed

TensorFlow.NET.sln

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,51 +16,95 @@ EndProject
1616
Global
1717
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1818
Debug|Any CPU = Debug|Any CPU
19+
Debug|x64 = Debug|x64
1920
Debug-Minimal|Any CPU = Debug-Minimal|Any CPU
21+
Debug-Minimal|x64 = Debug-Minimal|x64
2022
Publish|Any CPU = Publish|Any CPU
23+
Publish|x64 = Publish|x64
2124
Release|Any CPU = Release|Any CPU
25+
Release|x64 = Release|x64
2226
EndGlobalSection
2327
GlobalSection(ProjectConfigurationPlatforms) = postSolution
2428
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
2529
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
30+
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|x64
31+
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|x64
2632
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
2733
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
34+
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|x64
35+
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|x64
2836
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU
2937
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU
38+
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|x64
39+
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|x64
3040
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
3141
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
42+
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|x64
43+
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|x64
3244
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
3345
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU
46+
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|x64
47+
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|x64
3448
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
3549
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
50+
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|x64
51+
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|x64
3652
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU
3753
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU
54+
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|x64
55+
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|x64
3856
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU
3957
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU
58+
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|x64
59+
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.Build.0 = Release|x64
4060
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
4161
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU
62+
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|x64
63+
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|x64
4264
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
4365
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
66+
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|x64
67+
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|x64
4468
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU
4569
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU
70+
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|x64
71+
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|x64
4672
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU
4773
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU
74+
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|x64
75+
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|x64
4876
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
4977
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU
78+
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|x64
79+
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|x64
5080
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
5181
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
82+
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|x64
83+
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|x64
5284
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU
5385
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU
86+
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|x64
87+
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|x64
5488
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU
5589
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU
90+
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|x64
91+
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.Build.0 = Release|x64
5692
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
5793
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU
94+
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|x64
95+
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|x64
5896
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
5997
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
98+
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|x64
99+
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|x64
60100
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU
61101
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU
102+
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|x64
103+
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|x64
62104
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU
63105
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU
106+
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|x64
107+
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|x64
64108
EndGlobalSection
65109
GlobalSection(SolutionProperties) = preSolution
66110
HideSolutionNode = FALSE

src/TensorFlowNET.Core/APIs/tf.gradients.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ namespace Tensorflow
2020
{
2121
public partial class tensorflow
2222
{
23-
public GradientActor GradientTape()
24-
=> new GradientActor();
23+
public GradientTape GradientTape()
24+
=> new GradientTape();
2525

2626
public Tensor[] gradients(Tensor[] ys,
2727
Tensor[] xs,

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ public Tensor relu(Tensor features, string name = null)
123123
=> gen_nn_ops.relu(features, name);
124124

125125
public Tensor[] fused_batch_norm(Tensor x,
126-
VariableV1 scale,
127-
VariableV1 offset,
126+
IVariableV1 scale,
127+
IVariableV1 offset,
128128
Tensor mean = null,
129129
Tensor variance = null,
130130
float epsilon = 0.001f,

src/TensorFlowNET.Core/APIs/tf.train.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam")
5050
public ExponentialMovingAverage ExponentialMovingAverage(float decay)
5151
=> new ExponentialMovingAverage(decay);
5252

53-
public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5)
53+
public Saver Saver(IVariableV1[] var_list = null, int max_to_keep = 5)
5454
=> new Saver(var_list: var_list, max_to_keep: max_to_keep);
5555

5656
public string write_graph(Graph graph, string logdir, string name, bool as_text = true)
@@ -68,7 +68,7 @@ public Saver import_meta_graph(string meta_graph_or_file,
6868
clear_devices,
6969
import_scope).Item1;
7070

71-
public (MetaGraphDef, Dictionary<string, VariableV1>) export_meta_graph(string filename = "",
71+
public (MetaGraphDef, Dictionary<string, IVariableV1>) export_meta_graph(string filename = "",
7272
bool as_text = false,
7373
bool clear_devices = false,
7474
bool clear_extraneous_savers = false,

src/TensorFlowNET.Core/APIs/tf.variable.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ namespace Tensorflow
2121
{
2222
public partial class tensorflow
2323
{
24-
public VariableV1[] global_variables(string scope = null)
24+
public IVariableV1[] global_variables(string scope = null)
2525
{
26-
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>)
26+
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<IVariableV1>)
2727
.ToArray();
2828
}
2929

@@ -33,7 +33,7 @@ public VariableV1[] global_variables(string scope = null)
3333
/// <param name="var_list">List of `Variable` objects to initialize.</param>
3434
/// <param name="name">Optional name for the returned operation.</param>
3535
/// <returns>An Op that run the initializers of all the specified variables.</returns>
36-
public Operation variables_initializer(VariableV1[] var_list, string name = "init")
36+
public Operation variables_initializer(IVariableV1[] var_list, string name = "init")
3737
=> variables.variables_initializer(var_list, name: name);
3838

3939
public Operation global_variables_initializer()
@@ -47,8 +47,8 @@ public Operation global_variables_initializer()
4747
/// </summary>
4848
/// <param name="scope"></param>
4949
/// <returns></returns>
50-
public VariableV1[] trainable_variables(string scope = null)
51-
=> (variables.trainable_variables() as List<VariableV1>).ToArray();
50+
public IVariableV1[] trainable_variables(string scope = null)
51+
=> (variables.trainable_variables() as List<IVariableV1>).ToArray();
5252

5353
public RefVariable get_variable(string name,
5454
TensorShape shape = null,

src/TensorFlowNET.Core/Eager/EagerOperation.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ public class EagerOperation : Operation
88
{
99
public int NumInputs;
1010
public Tensor[] Inputs { get; set; }
11+
public int[] SkipInputIndices { get; set; }
1112

1213
public EagerOperation() : base(IntPtr.Zero) { }
1314

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,17 @@ public partial class c_api
1111
public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer);
1212

1313
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
14-
public delegate IntPtr _gradient_function_callback(string op_name, int num_inputs, IntPtr op_inputs, int num_attrs, int num_outputs, IntPtr output_grads);
14+
public delegate IntPtr _gradient_function_callback(string op_name,
15+
int num_inputs,
16+
IntPtr op_inputs,
17+
int num_attrs,
18+
int num_outputs,
19+
IntPtr output_grads,
20+
int num_skip_inputs,
21+
IntPtr skip_input_indices);
22+
23+
[DllImport(TensorFlowLibName)]
24+
public static extern IntPtr TFE_WrapGradientResult(IntPtr[] gradients, int num_gradients);
1525

1626
[DllImport(TensorFlowLibName)]
1727
public static extern IntPtr VSpace_Handle(VSpace_callback_Ones ones, VSpace_callback_AggregateGrads aggregate_grads);
@@ -373,11 +383,17 @@ public static extern IntPtr TFE_QuickExecute(IntPtr ctx,
373383
public static extern void TFE_TapeSetRemove(IntPtr tape);
374384

375385
[DllImport(TensorFlowLibName)]
376-
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor);
386+
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr variable);
377387

378388
[DllImport(TensorFlowLibName)]
379389
public static extern void TFE_TapeVariableAccessed(IntPtr variable);
380-
390+
391+
[DllImport(TensorFlowLibName)]
392+
public static extern IntPtr TFE_TapeWatchedVariables(IntPtr tape);
393+
394+
[DllImport(TensorFlowLibName)]
395+
public static extern IntPtr ResourceVariable_Handle(IntPtr variable);
396+
381397
[DllImport(TensorFlowLibName)]
382398
public static extern IntPtr TFE_TapeGradient(IntPtr tape,
383399
IntPtr[] target, int target_size,

src/TensorFlowNET.Core/Framework/meta_graph.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public static MetaGraphDef read_meta_graph_file(string filename)
3535
return meta_graph_def;
3636
}
3737

38-
public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
38+
public static (Dictionary<string, IVariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
3939
bool clear_devices = false,
4040
string import_scope = "",
4141
Dictionary<string, Tensor> input_map = null,
@@ -77,7 +77,7 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
7777
return_elements: return_elements);
7878

7979
// Restores all the other collections.
80-
var variable_objects = new Dictionary<ByteString, VariableV1>();
80+
var variable_objects = new Dictionary<ByteString, IVariableV1>();
8181
foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
8282
{
8383
// Don't add unbound_inputs to the new graph.
@@ -99,7 +99,7 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
9999
{
100100
foreach (var value in col.Value.BytesList.Value)
101101
{
102-
VariableV1 variable = null;
102+
IVariableV1 variable = null;
103103
if (!variable_objects.ContainsKey(value))
104104
{
105105
var proto = VariableDef.Parser.ParseFrom(value);
@@ -147,10 +147,10 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
147147
}
148148
}
149149

150-
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
150+
var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
151151
scope: scope_to_prepend_to_names);
152-
var var_list = new Dictionary<string, VariableV1>();
153-
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);
152+
var var_list = new Dictionary<string, IVariableV1>();
153+
variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);
154154

155155
return (var_list, imported_return_elements);
156156
}
@@ -168,7 +168,7 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
168168
/// <param name="strip_default_attrs"></param>
169169
/// <param name="meta_info_def"></param>
170170
/// <returns></returns>
171-
public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_graph(string filename = "",
171+
public static (MetaGraphDef, Dictionary<string, IVariableV1>) export_scoped_meta_graph(string filename = "",
172172
GraphDef graph_def = null,
173173
bool as_text = false,
174174
string unbound_inputs_col_name = "unbound_inputs",
@@ -180,14 +180,14 @@ public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_
180180
{
181181
var graph = ops.get_default_graph();
182182

183-
var var_list = new Dictionary<string, VariableV1>();
184-
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES);
183+
var var_list = new Dictionary<string, IVariableV1>();
184+
var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES);
185185

186186
if (variables != null)
187187
{
188188
foreach (var v in variables)
189189
{
190-
var_list[v.name] = v;
190+
var_list[v.Name] = v;
191191
}
192192
}
193193

@@ -268,7 +268,7 @@ private static void add_collection_def(MetaGraphDef meta_graph_def,
268268

269269
switch (graph.get_collection(key))
270270
{
271-
case List<VariableV1> collection_list:
271+
case List<IVariableV1> collection_list:
272272
col_def.BytesList = new Types.BytesList();
273273
foreach (var x in collection_list)
274274
{

0 commit comments

Comments
 (0)