forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVariableTest.cs
More file actions
152 lines (132 loc) · 4.45 KB
/
VariableTest.cs
File metadata and controls
152 lines (132 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// Unit tests for creating, initializing, naming, scoping and assigning
    /// TensorFlow.NET variables.
    /// </summary>
    /// <remarks>
    /// Every test begins with <c>tf.Graph().as_default()</c> so that op/variable
    /// names and <c>tf.global_variables_initializer()</c> do not depend on state
    /// left in the shared default graph by previously executed tests.
    /// </remarks>
    [TestClass]
    public class VariableTest
    {
        /// <summary>
        /// Running a variable's own initializer makes its initial value readable.
        /// </summary>
        [TestMethod]
        public void Initializer()
        {
            tf.Graph().as_default();
            var x = tf.Variable(10, name: "x");
            using (var session = tf.Session())
            {
                session.run(x.initializer);
                var result = session.run(x);
                Assert.AreEqual(10, (int)result);
            }
        }

        /// <summary>
        /// String-valued variables can be created, with an explicit name and dtype
        /// or with no name at all.
        /// </summary>
        [TestMethod]
        public void StringVar()
        {
            tf.Graph().as_default();
            var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string);
            var mammal2 = tf.Variable("Tiger");

            // In a fresh graph the explicit name is used verbatim, while the
            // unnamed variable falls back to TensorFlow's default "Variable" name.
            mammal1.name.Should().Be("var1:0");
            mammal2.name.Should().Be("Variable:0");
        }

        /// <summary>
        /// https://www.tensorflow.org/api_docs/python/tf/variable_scope
        /// Nested variable scopes prefix the names of variables created inside them.
        /// </summary>
        [TestMethod]
        public void VarCreation()
        {
            tf.Graph().as_default();
            tf_with(tf.variable_scope("foo"), delegate
            {
                tf_with(tf.variable_scope("bar"), delegate
                {
                    var v = tf.get_variable("v", new TensorShape(1));
                    v.name.Should().Be("foo/bar/v:0");
                });
            });
        }

        /// <summary>
        /// Re-entering a previously captured variable scope, and restoring its
        /// original name scope, places both variables and plain ops under that scope.
        /// </summary>
        [TestMethod]
        public void ReenterVariableScope()
        {
            tf.Graph().as_default();
            variable_scope vs = null;
            tf_with(tf.variable_scope("foo"), v => vs = v);

            // Re-enter the variable scope without opening an auxiliary name scope.
            tf_with(tf.variable_scope(vs, auxiliary_name_scope: false), v =>
            {
                var vs1 = (VariableScope)v;
                // Restore the original name_scope so plain ops are also prefixed "foo/".
                tf_with(tf.name_scope(vs1.original_name_scope), delegate
                {
                    var v1 = tf.get_variable("v", new TensorShape(1));
                    Assert.AreEqual("foo/v:0", v1.name);
                    var c1 = tf.constant(new int[] { 1 }, name: "c");
                    Assert.AreEqual("foo/c:0", c1.name);
                });
            });
        }

        /// <summary>
        /// A variable initialized from a tensor expression (x + 1) evaluates to that
        /// expression's value after global initialization.
        /// </summary>
        [TestMethod]
        public void ScalarVar()
        {
            tf.Graph().as_default();
            var x = tf.constant(3, name: "x");
            var y = tf.Variable(x + 1, name: "y");
            var model = tf.global_variables_initializer();
            using (var session = tf.Session())
            {
                session.run(model);
                int result = session.run(y);
                Assert.AreEqual(4, result);
            }
        }

        /// <summary>
        /// Running an assign op returns the newly assigned value.
        /// </summary>
        [TestMethod]
        public void Assign1()
        {
            var graph = tf.Graph().as_default();
            var variable = tf.Variable(31, name: "tree");
            var init = tf.global_variables_initializer();

            // Dispose the session deterministically instead of leaking it.
            using (var sess = tf.Session(graph))
            {
                sess.run(init);
                NDArray result = sess.run(variable);
                Assert.AreEqual(31, (int)result);

                var assign = variable.assign(12);
                result = sess.run(assign);
                Assert.AreEqual(12, (int)result);
            }
        }

        /// <summary>
        /// Running an "increment" assign op (v1 = v1 + 1) once updates the
        /// variable's stored value.
        /// </summary>
        [TestMethod]
        public void Assign2()
        {
            tf.Graph().as_default();
            var v1 = tf.Variable(10.0f, name: "v1");
            var inc_v1 = v1.assign(v1 + 1.0f);
            // Add an op to initialize the variables.
            var init_op = tf.global_variables_initializer();
            using (var sess = tf.Session())
            {
                sess.run(init_op);
                // Do some work with the model: run the increment exactly once.
                inc_v1.op.run(session: sess);

                // Reading v1 afterwards must observe the assignment (10 + 1).
                NDArray value = sess.run(v1);
                Assert.AreEqual(11.0f, (float)value, 1e-6f);
            }
        }

        /// <summary>
        /// https://databricks.com/tensorflow/variables
        /// Repeatedly adding 1 to a tensor derived from a variable builds a chain of
        /// add ops; evaluating the last one yields 10 + 5.
        /// </summary>
        [TestMethod]
        public void Add()
        {
            tf.Graph().as_default();
            int result = 0;
            Tensor x = tf.Variable(10, name: "x");
            var init_op = tf.global_variables_initializer();
            using (var session = tf.Session())
            {
                session.run(init_op);
                for (int i = 0; i < 5; i++)
                {
                    // Rebinds the local to a new add op; the variable itself is not assigned.
                    x = x + 1;
                    result = session.run(x);
                    print(result);
                }
            }
            Assert.AreEqual(15, result);
        }
    }
}