Skip to content

Commit b1bd05c

Browse files
committed
Expose tf.shape() API.
1 parent d2fef2f commit b1bd05c

File tree

5 files changed

+20
-6
lines changed

5 files changed

+20
-6
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,15 @@ public Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", stri
124124
/// <returns>A `Tensor`. Has the same type as `input`.</returns>
125125
public Tensor placeholder_with_default<T>(T input, int[] shape, string name = null)
126126
=> gen_array_ops.placeholder_with_default(input, shape, name: name);
127+
128+
/// <summary>
129+
/// Returns the shape of a tensor.
130+
/// </summary>
131+
/// <param name="input"></param>
132+
/// <param name="name"></param>
133+
/// <param name="out_type"></param>
134+
/// <returns></returns>
135+
public Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32)
136+
=> array_ops.shape_internal(input, name, optimize: true, out_type: out_type);
127137
}
128138
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ public static Tensor shape(Tensor input, string name = null, TF_DataType out_typ
338338
public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
339339
=> size_internal(input, name, optimize: optimize, out_type: out_type);
340340

341-
private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
341+
public static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
342342
{
343343
return tf_with(ops.name_scope(name, "Shape", new { input }), scope =>
344344
{

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
2626
5. Overload session.run(), make syntax simpler.
2727
6. Add Local Response Normalization.
2828
7. Add tf.image related APIs.
29-
8. Add tf.random_normal, tf.constant, tf.pad.
29+
8. Add tf.random_normal, tf.constant, tf.pad, tf.shape.
3030
9. MultiThread is safe.</PackageReleaseNotes>
3131
<LangVersion>7.3</LangVersion>
3232
<FileVersion>0.11.2.0</FileVersion>

test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,14 @@ public static Tensor upsample(Tensor input_data, string name, string method = "d
6565
Tensor output = null;
6666
if (method == "resize")
6767
{
68-
68+
tf_with(tf.variable_scope(name), delegate
69+
{
70+
var input_shape = tf.shape(input_data);
71+
});
6972
}
7073
else if(method == "deconv")
7174
{
72-
75+
throw new NotImplementedException("upsample.deconv");
7376
}
7477

7578
return output;

test/TensorFlowNET.UnitTest/ImageTest.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ namespace TensorFlowNET.UnitTest
1414
[TestClass]
1515
public class ImageTest
1616
{
17-
string imgPath = "../../../../../data/shasta-daisy.jpg";
17+
string imgPath = "shasta-daisy.jpg";
1818
Tensor contents;
1919

20-
public ImageTest()
20+
[TestInitialize]
21+
public void Initialize()
2122
{
2223
imgPath = Path.GetFullPath(imgPath);
2324
contents = tf.read_file(imgPath);

0 commit comments

Comments
 (0)