Skip to content

Commit 269611b

Browse files
committed
change None to Unknown for TensorShape. Remove object[] construction.
1 parent 420e195 commit 269611b

File tree

5 files changed

+41
-82
lines changed

5 files changed

+41
-82
lines changed

src/TensorFlowNET.Core/Binding.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@ public static partial class Binding
1010

1111
/// <summary>
1212
/// Alias to null, similar to python's None.
13+
/// For TensorShape, please use Unknown
1314
/// </summary>
1415
public static readonly object None = null;
16+
17+
/// <summary>
18+
/// Used for TensorShape None
19+
/// </summary>
20+
public static readonly int Unknown = -1;
1521
}
1622
}

src/TensorFlowNET.Core/Tensors/TensorShape.cs

Lines changed: 19 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -71,61 +71,32 @@ public TensorShape(TensorShapeProto proto)
7171
}
7272
}
7373

74-
public TensorShape(params object[] dims)
74+
public TensorShape(params int[] dims)
7575
{
76-
Array arr;
76+
switch (dims.Length)
77+
{
78+
case 0: shape = new Shape(new int[0]); break;
79+
case 1: shape = Shape.Vector((int)dims[0]); break;
80+
case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
81+
default: shape = new Shape(dims); break;
82+
}
83+
}
7784

78-
if (dims.Length == 1)
85+
public TensorShape(int[][] dims)
86+
{
87+
if(dims.Length == 1)
7988
{
80-
switch (dims[0])
89+
switch (dims[0].Length)
8190
{
82-
case int[] intarr:
83-
arr = intarr;
84-
break;
85-
case long[] longarr:
86-
arr = longarr;
87-
break;
88-
case object[] objarr:
89-
arr = objarr;
90-
break;
91-
case int _:
92-
case long _:
93-
arr = dims;
94-
break;
95-
case null: //==Binding.None
96-
arr = dims;
97-
break;
98-
default:
99-
Binding.print(dims);
100-
throw new ArgumentException(nameof(dims));
91+
case 0: shape = new Shape(new int[0]); break;
92+
case 1: shape = Shape.Vector((int)dims[0][0]); break;
93+
case 2: shape = Shape.Matrix(dims[0][0], dims[1][2]); break;
94+
default: shape = new Shape(dims[0]); break;
10195
}
102-
} else
103-
arr = dims;
104-
105-
var intdims = new int[arr.Length];
106-
for (int i = 0; i < arr.Length; i++)
107-
{
108-
var val = arr.GetValue(i);
109-
if (val == Binding.None)
110-
intdims[i] = -1;
111-
else
112-
intdims[i] = Converts.ToInt32(val);
11396
}
114-
115-
switch (intdims.Length)
97+
else
11698
{
117-
case 0:
118-
shape = new Shape(new int[0]);
119-
break;
120-
case 1:
121-
shape = Shape.Vector((int) intdims[0]);
122-
break;
123-
case 2:
124-
shape = Shape.Matrix(intdims[0], intdims[1]);
125-
break;
126-
default:
127-
shape = new Shape(intdims);
128-
break;
99+
throw new NotImplementedException("TensorShape int[][] dims");
129100
}
130101
}
131102

@@ -232,8 +203,6 @@ public override string ToString()
232203
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
233204
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());
234205

235-
public static implicit operator TensorShape(object[] dims) => new TensorShape(dims);
236-
237206
public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes
238207
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
239208

@@ -260,16 +229,5 @@ public override string ToString()
260229

261230
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
262231
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
263-
264-
public static implicit operator TensorShape(int?[] dims) => new TensorShape(dims);
265-
public static implicit operator TensorShape(int? dim) => new TensorShape(dim);
266-
public static implicit operator TensorShape((object, object) dims) => new TensorShape(dims.Item1, dims.Item2);
267-
public static implicit operator TensorShape((object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
268-
public static implicit operator TensorShape((object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
269-
public static implicit operator TensorShape((object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
270-
public static implicit operator TensorShape((object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
271-
public static implicit operator TensorShape((object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
272-
public static implicit operator TensorShape((object, object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
273-
274232
}
275233
}

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,6 @@ public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, st
6363
return gen_array_ops.placeholder(dtype, shape, name);
6464
}
6565

66-
public unsafe Tensor placeholder(TF_DataType dtype, object[] shape, string name = null)
67-
{
68-
return placeholder(dtype, new TensorShape(shape), name);
69-
}
70-
7166
public void enable_eager_execution()
7267
{
7368
// contex = new Context();

test/TensorFlowNET.UnitTest/TensorShapeTest.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,48 @@ public class TensorShapeTest
1212
[TestMethod]
1313
public void Case1()
1414
{
15-
int? a = 2;
16-
int? b = 3;
17-
var dims = new object[] {(int?) None, a, b};
15+
int a = 2;
16+
int b = 3;
17+
var dims = new [] { Unknown, a, b};
1818
new TensorShape(dims).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
1919
}
2020

2121
[TestMethod]
2222
public void Case2()
2323
{
24-
int? a = 2;
25-
int? b = 3;
26-
var dims = new object[] {(int?) None, a, b};
27-
new TensorShape(new object[] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
24+
int a = 2;
25+
int b = 3;
26+
var dims = new[] { Unknown, a, b};
27+
new TensorShape(new [] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
2828
}
2929

3030
[TestMethod]
3131
public void Case3()
3232
{
33-
int? a = 2;
34-
int? b = null;
35-
var dims = new object[] {(int?) None, a, b};
36-
new TensorShape(new object[] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, -1);
33+
int a = 2;
34+
int b = Unknown;
35+
var dims = new [] { Unknown, a, b};
36+
new TensorShape(new [] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, -1);
3737
}
3838

3939
[TestMethod]
4040
public void Case4()
4141
{
42-
TensorShape shape = (None, None);
42+
TensorShape shape = (Unknown, Unknown);
4343
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, -1);
4444
}
4545

4646
[TestMethod]
4747
public void Case5()
4848
{
49-
TensorShape shape = (1, None, 3);
49+
TensorShape shape = (1, Unknown, 3);
5050
shape.GetPrivate<Shape>("shape").Should().BeShaped(1, -1, 3);
5151
}
5252

5353
[TestMethod]
5454
public void Case6()
5555
{
56-
TensorShape shape = (None, 1, 2, 3, None);
56+
TensorShape shape = (Unknown, 1, 2, 3, Unknown);
5757
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1);
5858
}
5959
}

test/TensorFlowNET.UnitTest/layers_test/flatten.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public void Case4()
4242
{
4343
var sess = tf.Session().as_default();
4444

45-
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, None, 1, 2));
45+
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
4646
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
4747
}
4848

@@ -51,7 +51,7 @@ public void Case5()
5151
{
5252
var sess = tf.Session().as_default();
5353

54-
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(None, 4, 3, 1, 2));
54+
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
5555
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
5656
}
5757
}

0 commit comments

Comments
 (0)