Skip to content

Commit a5ae56a

Browse files
committed
add tf.where api.
1 parent c42f8bb commit a5ae56a

File tree

9 files changed

+166
-124
lines changed

9 files changed

+166
-124
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ public static partial class tf
2020
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
2121
=> array_ops.expand_dims(input, axis, name, dim);
2222

23+
/// <summary>
24+
/// Return the elements, either from `x` or `y`, depending on the `condition`.
25+
/// </summary>
26+
/// <returns></returns>
27+
public static Tensor where<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
28+
=> array_ops.where(condition, x, y, name);
29+
2330
/// <summary>
2431
/// Transposes `a`. Permutes the dimensions according to `perm`.
2532
/// </summary>

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public static Tensor random_normal(int[] shape,
2525

2626
public static Tensor random_uniform(int[] shape,
2727
float minval = 0,
28-
float? maxval = null,
28+
float maxval = 1,
2929
TF_DataType dtype = TF_DataType.TF_FLOAT,
3030
int? seed = null,
3131
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using static Tensorflow.Python;
56

67
namespace Tensorflow.Layers
78
{
@@ -50,7 +51,7 @@ public Tensor __call__(Tensor inputs,
5051
auxiliary_name_scope: false);
5152
}
5253

53-
Python.with(scope_context_manager, scope2 => _current_scope = scope2);
54+
with(scope_context_manager, scope2 => _current_scope = scope2);
5455
// Actually call layer
5556
var outputs = base.__call__(new Tensor[] { inputs }, training: training);
5657

@@ -60,6 +61,13 @@ public Tensor __call__(Tensor inputs,
6061
return outputs;
6162
}
6263

64+
protected override void _init_set_name(string name, bool zero_based = true)
65+
{
66+
// Determine layer name (non-unique).
67+
base._init_set_name(name, zero_based: zero_based);
68+
_base_name = this.name;
69+
}
70+
6371
protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)
6472
{
6573
foreach(var name in collection_list)
@@ -140,10 +148,18 @@ private void _set_scope(VariableScope scope = null)
140148
{
141149
if (_scope == null)
142150
{
143-
Python.with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
151+
if(_reuse.HasValue && _reuse.Value)
144152
{
145-
_scope = captured_scope;
146-
});
153+
throw new NotImplementedException("_set_scope _reuse.HasValue");
154+
/*with(tf.variable_scope(scope == null ? _base_name : scope),
155+
captured_scope => _scope = captured_scope);*/
156+
}
157+
else
158+
{
159+
with(tf.variable_scope(scope, default_name: _base_name),
160+
captured_scope => _scope = captured_scope);
161+
}
162+
147163
}
148164
}
149165
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ public static Tensor one_hot(Tensor indices, int depth,
233233
});
234234
}
235235

236-
public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null)
236+
public static Tensor where(Tensor condition, object x = null, object y = null, string name = null)
237237
{
238238
if( x == null && y == null)
239239
{

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ public static Tensor placeholder_with_default<T>(T input, int[] shape, string na
234234
return _op.outputs[0];
235235
}
236236

237-
public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = null)
237+
public static Tensor select<Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null)
238238
{
239239
var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e });
240240
return _op.outputs[0];

src/TensorFlowNET.Core/Operations/random_ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public static Tensor random_normal(int[] shape,
4949
/// <returns>A tensor of the specified shape filled with random uniform values.</returns>
5050
public static Tensor random_uniform(int[] shape,
5151
float minval = 0,
52-
float? maxval = null,
52+
float maxval = 1,
5353
TF_DataType dtype = TF_DataType.TF_FLOAT,
5454
int? seed = null,
5555
string name = null)

src/TensorFlowNET.Core/tf.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public static partial class tf
1515
public static TF_DataType float16 = TF_DataType.TF_HALF;
1616
public static TF_DataType float32 = TF_DataType.TF_FLOAT;
1717
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
18-
public static TF_DataType boolean = TF_DataType.TF_BOOL;
18+
public static TF_DataType @bool = TF_DataType.TF_BOOL;
1919
public static TF_DataType chars = TF_DataType.TF_STRING;
2020

2121
public static Context context = new Context(new ContextOptions(), new Status());

test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

Lines changed: 133 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -40,43 +40,157 @@ public class CnnTextClassification : IExample
4040

4141
protected float loss_value = 0;
4242
int vocabulary_size = 50000;
43+
NDArray train_x, valid_x, train_y, valid_y;
4344

4445
public bool Run()
4546
{
4647
PrepareData();
4748

48-
var graph = tf.Graph().as_default();
49-
return with(tf.Session(graph), sess =>
49+
Train();
50+
51+
return true;
52+
}
53+
54+
// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
55+
private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
56+
{
57+
Console.WriteLine("Splitting in Training and Testing data...");
58+
int len = x.shape[0];
59+
//int classes = y.Data<int>().Distinct().Count();
60+
//int samples = len / classes;
61+
int train_size = (int)Math.Round(len * (1 - test_size));
62+
var train_x = x[new Slice(stop: train_size), new Slice()];
63+
var valid_x = x[new Slice(start: train_size), new Slice()];
64+
var train_y = y[new Slice(stop: train_size)];
65+
var valid_y = y[new Slice(start: train_size)];
66+
Console.WriteLine("\tDONE");
67+
return (train_x, valid_x, train_y, valid_y);
68+
}
69+
70+
private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
71+
{
72+
int i = 0;
73+
var label_keys = labels.Keys.ToArray();
74+
while (i < shuffled_x.Length)
5075
{
51-
if (IsImportingGraph)
52-
return RunWithImportedGraph(sess, graph);
53-
else
54-
return RunWithBuiltGraph(sess, graph);
55-
});
76+
var key = label_keys[random.Next(label_keys.Length)];
77+
var set = labels[key];
78+
var index = set.First();
79+
if (set.Count == 0)
80+
{
81+
labels.Remove(key); // remove the set as it is empty
82+
label_keys = labels.Keys.ToArray();
83+
}
84+
shuffled_x[i] = x[index];
85+
shuffled_y[i] = y[index];
86+
i++;
87+
}
5688
}
5789

58-
protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
90+
private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
5991
{
60-
var stopwatch = Stopwatch.StartNew();
92+
var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
93+
var total_batches = num_batches_per_epoch * num_epochs;
94+
foreach (var epoch in range(num_epochs))
95+
{
96+
foreach (var batch_num in range(num_batches_per_epoch))
97+
{
98+
var start_index = batch_num * batch_size;
99+
var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
100+
if (end_index <= start_index)
101+
break;
102+
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
103+
}
104+
}
105+
}
106+
107+
public void PrepareData()
108+
{
109+
// full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
110+
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
111+
Web.Download(url, dataDir, "dbpedia_subset.zip");
112+
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
113+
61114
Console.WriteLine("Building dataset...");
62-
int[][] x = null;
63-
int[] y = null;
115+
64116
int alphabet_size = 0;
65117

66118
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
67-
// vocabulary_size = len(word_dict);
68-
(x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
119+
vocabulary_size = len(word_dict);
120+
var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
69121

70122
Console.WriteLine("\tDONE ");
71123

72124
var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
73125
Console.WriteLine("Training set size: " + train_x.len);
74126
Console.WriteLine("Test set size: " + valid_x.len);
127+
}
75128

76-
Console.WriteLine("Import graph...");
129+
public Graph ImportGraph()
130+
{
131+
var graph = tf.Graph().as_default();
132+
133+
// download graph meta data
77134
var meta_file = "word_cnn.meta";
135+
var meta_path = Path.Combine("graph", meta_file);
136+
if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
137+
{
138+
// delete old cached file which contains errors
139+
Console.WriteLine("Discarding cached file: " + meta_path);
140+
File.Delete(meta_path);
141+
}
142+
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
143+
Web.Download(url, "graph", meta_file);
144+
145+
Console.WriteLine("Import graph...");
78146
tf.train.import_meta_graph(Path.Join("graph", meta_file));
79-
Console.WriteLine("\tDONE " + stopwatch.Elapsed);
147+
Console.WriteLine("\tDONE ");
148+
149+
return graph;
150+
}
151+
152+
public Graph BuildGraph()
153+
{
154+
var graph = tf.Graph().as_default();
155+
156+
var embedding_size = 128;
157+
var learning_rate = 0.001f;
158+
var filter_sizes = new int[3, 4, 5];
159+
var num_filters = 100;
160+
var document_max_len = 100;
161+
162+
var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
163+
var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
164+
var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
165+
var global_step = tf.Variable(0, trainable: false);
166+
var keep_prob = tf.where(is_training, 0.5, 1.0);
167+
Tensor x_emb = null;
168+
169+
with(tf.name_scope("embedding"), scope =>
170+
{
171+
var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
172+
var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
173+
x_emb = tf.nn.embedding_lookup(embeddings, x);
174+
x_emb = tf.expand_dims(x_emb, -1);
175+
});
176+
177+
foreach(var filter_size in filter_sizes)
178+
{
179+
var conv = tf.layers.conv2d(
180+
x_emb,
181+
filters: num_filters,
182+
kernel_size: new int[] { filter_size, embedding_size },
183+
strides: new int[] { 1, 1 },
184+
padding: "VALID",
185+
activation: tf.nn.relu());
186+
}
187+
188+
return graph;
189+
}
190+
191+
private bool RunWithImportedGraph(Session sess, Graph graph)
192+
{
193+
var stopwatch = Stopwatch.StartNew();
80194

81195
sess.run(tf.global_variables_initializer());
82196
var saver = tf.train.Saver(tf.global_variables());
@@ -149,107 +263,12 @@ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
149263
return false;
150264
}
151265

152-
protected virtual bool RunWithBuiltGraph(Session session, Graph graph)
153-
{
154-
Console.WriteLine("Building dataset...");
155-
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "word_cnn", CHAR_MAX_LEN, DataLimit);
156-
157-
var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
158-
159-
ITextClassificationModel model = null;
160-
// todo train the model
161-
return false;
162-
}
163-
164-
// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
165-
private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
166-
{
167-
Console.WriteLine("Splitting in Training and Testing data...");
168-
int len = x.shape[0];
169-
//int classes = y.Data<int>().Distinct().Count();
170-
//int samples = len / classes;
171-
int train_size = (int)Math.Round(len * (1 - test_size));
172-
var train_x = x[new Slice(stop: train_size), new Slice()];
173-
var valid_x = x[new Slice(start: train_size), new Slice()];
174-
var train_y = y[new Slice(stop: train_size)];
175-
var valid_y = y[new Slice(start: train_size)];
176-
Console.WriteLine("\tDONE");
177-
return (train_x, valid_x, train_y, valid_y);
178-
}
179-
180-
private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
181-
{
182-
int i = 0;
183-
var label_keys = labels.Keys.ToArray();
184-
while (i < shuffled_x.Length)
185-
{
186-
var key = label_keys[random.Next(label_keys.Length)];
187-
var set = labels[key];
188-
var index = set.First();
189-
if (set.Count == 0)
190-
{
191-
labels.Remove(key); // remove the set as it is empty
192-
label_keys = labels.Keys.ToArray();
193-
}
194-
shuffled_x[i] = x[index];
195-
shuffled_y[i] = y[index];
196-
i++;
197-
}
198-
}
199-
200-
private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
201-
{
202-
var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
203-
var total_batches = num_batches_per_epoch * num_epochs;
204-
foreach (var epoch in range(num_epochs))
205-
{
206-
foreach (var batch_num in range(num_batches_per_epoch))
207-
{
208-
var start_index = batch_num * batch_size;
209-
var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
210-
if (end_index <= start_index)
211-
break;
212-
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
213-
}
214-
}
215-
}
216-
217-
public void PrepareData()
218-
{
219-
// full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
220-
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
221-
Web.Download(url, dataDir, "dbpedia_subset.zip");
222-
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
223-
224-
if (IsImportingGraph)
225-
{
226-
// download graph meta data
227-
var meta_file = "word_cnn.meta";
228-
var meta_path = Path.Combine("graph", meta_file);
229-
if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
230-
{
231-
// delete old cached file which contains errors
232-
Console.WriteLine("Discarding cached file: " + meta_path);
233-
File.Delete(meta_path);
234-
}
235-
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
236-
Web.Download(url, "graph", meta_file);
237-
}
238-
}
239-
240-
public Graph ImportGraph()
241-
{
242-
throw new NotImplementedException();
243-
}
244-
245-
public Graph BuildGraph()
246-
{
247-
throw new NotImplementedException();
248-
}
249-
250266
public bool Train()
251267
{
252-
throw new NotImplementedException();
268+
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
269+
270+
return with(tf.Session(graph), sess
271+
=> RunWithImportedGraph(sess, graph));
253272
}
254273

255274
public bool Predict()

0 commit comments

Comments
 (0)