@@ -71,61 +71,32 @@ public TensorShape(TensorShapeProto proto)
7171 }
7272 }
7373
74- public TensorShape ( params object [ ] dims )
74+ public TensorShape ( params int [ ] dims )
7575 {
76- Array arr ;
76+ switch ( dims . Length )
77+ {
78+ case 0 : shape = new Shape ( new int [ 0 ] ) ; break ;
79+ case 1 : shape = Shape . Vector ( ( int ) dims [ 0 ] ) ; break ;
80+ case 2 : shape = Shape . Matrix ( dims [ 0 ] , dims [ 1 ] ) ; break ;
81+ default : shape = new Shape ( dims ) ; break ;
82+ }
83+ }
7784
78- if ( dims . Length == 1 )
85+ public TensorShape ( int [ ] [ ] dims )
86+ {
87+ if ( dims . Length == 1 )
7988 {
80- switch ( dims [ 0 ] )
89+ switch ( dims [ 0 ] . Length )
8190 {
82- case int [ ] intarr :
83- arr = intarr ;
84- break ;
85- case long [ ] longarr :
86- arr = longarr ;
87- break ;
88- case object [ ] objarr :
89- arr = objarr ;
90- break ;
91- case int _:
92- case long _:
93- arr = dims ;
94- break ;
95- case null : //==Binding.None
96- arr = dims ;
97- break ;
98- default :
99- Binding . print ( dims ) ;
100- throw new ArgumentException ( nameof ( dims ) ) ;
91+ case 0 : shape = new Shape ( new int [ 0 ] ) ; break ;
92+ case 1 : shape = Shape . Vector ( ( int ) dims [ 0 ] [ 0 ] ) ; break ;
93+ case 2 : shape = Shape . Matrix ( dims [ 0 ] [ 0 ] , dims [ 1 ] [ 2 ] ) ; break ;
94+ default : shape = new Shape ( dims [ 0 ] ) ; break ;
10195 }
102- } else
103- arr = dims ;
104-
105- var intdims = new int [ arr . Length ] ;
106- for ( int i = 0 ; i < arr . Length ; i ++ )
107- {
108- var val = arr . GetValue ( i ) ;
109- if ( val == Binding . None )
110- intdims [ i ] = - 1 ;
111- else
112- intdims [ i ] = Converts . ToInt32 ( val ) ;
11396 }
114-
115- switch ( intdims . Length )
97+ else
11698 {
117- case 0 :
118- shape = new Shape ( new int [ 0 ] ) ;
119- break ;
120- case 1 :
121- shape = Shape . Vector ( ( int ) intdims [ 0 ] ) ;
122- break ;
123- case 2 :
124- shape = Shape . Matrix ( intdims [ 0 ] , intdims [ 1 ] ) ;
125- break ;
126- default :
127- shape = new Shape ( intdims ) ;
128- break ;
99+ throw new NotImplementedException ( "TensorShape int[][] dims" ) ;
129100 }
130101 }
131102
@@ -232,8 +203,6 @@ public override string ToString()
232203 public static implicit operator TensorShape ( Shape shape ) => new TensorShape ( ( int [ ] ) shape . Dimensions . Clone ( ) ) ;
233204 public static implicit operator Shape ( TensorShape shape ) => new Shape ( ( int [ ] ) shape . dims . Clone ( ) ) ;
234205
235- public static implicit operator TensorShape ( object [ ] dims ) => new TensorShape ( dims ) ;
236-
237206 public static implicit operator int [ ] ( TensorShape shape ) => ( int [ ] ) shape . dims . Clone ( ) ; //we clone to avoid any changes
238207 public static implicit operator TensorShape ( int [ ] dims ) => new TensorShape ( dims ) ;
239208
@@ -260,16 +229,5 @@ public override string ToString()
260229
261230 public static explicit operator ( int , int , int , int , int , int , int , int ) ( TensorShape shape ) => shape . dims . Length == 8 ? ( shape . dims [ 0 ] , shape . dims [ 1 ] , shape . dims [ 2 ] , shape . dims [ 3 ] , shape . dims [ 4 ] , shape . dims [ 5 ] , shape . dims [ 6 ] , shape . dims [ 7 ] ) : ( 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ) ;
262231 public static implicit operator TensorShape ( ( int , int , int , int , int , int , int , int ) dims ) => new TensorShape ( dims . Item1 , dims . Item2 , dims . Item3 , dims . Item4 , dims . Item5 , dims . Item6 , dims . Item7 , dims . Item8 ) ;
263-
264- public static implicit operator TensorShape ( int ? [ ] dims ) => new TensorShape ( dims ) ;
265- public static implicit operator TensorShape ( int ? dim ) => new TensorShape ( dim ) ;
266- public static implicit operator TensorShape ( ( object , object ) dims ) => new TensorShape ( dims . Item1 , dims . Item2 ) ;
267- public static implicit operator TensorShape ( ( object , object , object ) dims ) => new TensorShape ( dims . Item1 , dims . Item2 , dims . Item3 ) ;
268- public static implicit operator TensorShape ( ( object , object , object , object ) dims ) => new TensorShape ( dims . Item1 , dims . Item2 , dims . Item3 , dims . Item4 ) ;
269- public static implicit operator TensorShape ( ( object , object , object , object , object ) dims ) => new TensorShape ( dims . Item1 , dims . Item2 , dims . Item3 , dims . Item4 , dims . Item5 ) ;
270- public static implicit operator TensorShape ( ( object , object , object , object , object , object ) dims ) => new TensorShape ( dims . Item1 , dims . Item2 , dims . Item3 , dims . Item4 , dims . Item5 , dims . Item6 ) ;
271- public static implicit operator TensorShape ( ( object , object , object , object , object , object , object ) dims ) => new TensorShape ( dims . Item1 , dims . Item2 , dims . Item3 , dims . Item4 , dims . Item5 , dims . Item6 , dims . Item7 ) ;
272- public static implicit operator TensorShape ( ( object , object , object , object , object , object , object , object ) dims ) => new TensorShape ( dims . Item1 , dims . Item2 , dims . Item3 , dims . Item4 , dims . Item5 , dims . Item6 , dims . Item7 , dims . Item8 ) ;
273-
274232 }
275233}
0 commit comments