@@ -40,43 +40,157 @@ public class CnnTextClassification : IExample
4040
4141 protected float loss_value = 0 ;
4242 int vocabulary_size = 50000 ;
43+ NDArray train_x , valid_x , train_y , valid_y ;
4344
/// <summary>
/// Entry point for the example: prepares the dataset, then trains.
/// </summary>
public bool Run()
{
    // Download + preprocess first, then fit the model.
    PrepareData();

    // The example always reports success; Train's own bool result is ignored here.
    Train();
    return true;
}
53+
// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
/// <summary>
/// Splits x/y into a leading training portion and a trailing validation
/// portion. Rows keep their original order — no shuffling is performed.
/// </summary>
private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
{
    Console.WriteLine("Splitting in Training and Testing data...");

    int total = x.shape[0];
    //int classes = y.Data<int>().Distinct().Count();
    //int samples = len / classes;
    // Rows that belong to the training split, rounded to the nearest int.
    int cut = (int)Math.Round(total * (1 - test_size));

    var head_x = x[new Slice(stop: cut), new Slice()];
    var tail_x = x[new Slice(start: cut), new Slice()];
    var head_y = y[new Slice(stop: cut)];
    var tail_y = y[new Slice(start: cut)];

    Console.WriteLine("\t DONE");
    return (head_x, tail_x, head_y, tail_y);
}
69+
/// <summary>
/// Fills shuffled_x / shuffled_y by repeatedly drawing a random label bucket
/// from <paramref name="labels"/> and consuming one sample index from it.
/// Each index is used at most once; exhausted buckets are removed so the
/// random draw can never land on an empty set.
/// </summary>
private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
{
    int i = 0;
    var label_keys = labels.Keys.ToArray();
    while (i < shuffled_x.Length)
    {
        var key = label_keys[random.Next(label_keys.Length)];
        var set = labels[key];
        // BUG FIX: the original called set.First() before the emptiness check
        // (First() throws on an empty set) and never removed the index, so
        // set.Count == 0 was never reached and the same first sample was
        // emitted over and over. Consume the index instead.
        var index = set.First();
        set.Remove(index);
        if (set.Count == 0)
        {
            labels.Remove(key); // remove the set as it is empty
            label_keys = labels.Keys.ToArray();
        }
        shuffled_x[i] = x[index];
        shuffled_y[i] = y[index];
        i++;
    }
}
5789
/// <summary>
/// Lazily yields (inputs_slice, outputs_slice, total_batches) mini-batches,
/// sweeping the data once per epoch in original order.
/// </summary>
private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
{
    var batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
    var total_batches = batches_per_epoch * num_epochs;

    for (int epoch = 0; epoch < num_epochs; epoch++)
    {
        for (int batch = 0; batch < batches_per_epoch; batch++)
        {
            var first = batch * batch_size;
            var last = Math.Min((batch + 1) * batch_size, len(inputs));
            if (last <= first)
                break; // trailing slice is empty — nothing left in this epoch

            yield return (inputs[new Slice(first, last)], outputs[new Slice(first, last)], total_batches);
        }
    }
}
106+
/// <summary>
/// Downloads the dbpedia subset, builds the word dictionary and the padded
/// word-id dataset, and stores the train/validation split in the class
/// fields for Train() to consume.
/// </summary>
public void PrepareData()
{
    // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
    var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
    Web.Download(url, dataDir, "dbpedia_subset.zip");
    Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));

    Console.WriteLine("Building dataset...");

    var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
    vocabulary_size = len(word_dict);
    var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);

    Console.WriteLine("\t DONE ");

    // BUG FIX: the original wrote `var (train_x, ...) = ...`, which declares
    // four NEW locals and leaves the train_x/valid_x/train_y/valid_y fields
    // unassigned (null) when training later reads them. Assign the fields.
    (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
    Console.WriteLine("Training set size: " + train_x.len);
    Console.WriteLine("Test set size: " + valid_x.len);
}
75128
/// <summary>
/// Downloads the pre-built word_cnn meta graph (discarding a stale cached
/// copy) and imports it into a new default graph.
/// </summary>
public Graph ImportGraph()
{
    var graph = tf.Graph().as_default();

    // download graph meta data
    var meta_file = "word_cnn.meta";
    var meta_path = Path.Combine("graph", meta_file);
    if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
    {
        // delete old cached file which contains errors
        Console.WriteLine("Discarding cached file: " + meta_path);
        File.Delete(meta_path);
    }
    var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
    Web.Download(url, "graph", meta_file);

    Console.WriteLine("Import graph...");
    // Consistency fix: reuse the meta_path already built with Path.Combine
    // instead of re-deriving the same location via Path.Join.
    tf.train.import_meta_graph(meta_path);
    Console.WriteLine("\t DONE ");

    return graph;
}
151+
/// <summary>
/// Builds a word-level CNN classifier graph: embedding lookup followed by
/// parallel conv2d branches, one per filter width.
/// </summary>
public Graph BuildGraph()
{
    var graph = tf.Graph().as_default();

    var embedding_size = 128;
    var learning_rate = 0.001f;
    // BUG FIX: `new int[3, 4, 5]` declares a 3x4x5 MULTIDIMENSIONAL array of
    // zeros (60 elements), not the three intended filter widths. The foreach
    // below would then build 60 conv layers with kernel height 0.
    var filter_sizes = new int[] { 3, 4, 5 };
    var num_filters = 100;
    var document_max_len = 100;

    var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
    var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
    var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
    var global_step = tf.Variable(0, trainable: false);
    var keep_prob = tf.where(is_training, 0.5, 1.0);
    Tensor x_emb = null;

    with(tf.name_scope("embedding"), scope =>
    {
        var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
        var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
        x_emb = tf.nn.embedding_lookup(embeddings, x);
        x_emb = tf.expand_dims(x_emb, -1);
    });

    // NOTE(review): the conv outputs are not pooled/concatenated and
    // learning_rate / global_step / keep_prob are still unused — graph
    // construction looks unfinished; confirm before relying on this path.
    foreach (var filter_size in filter_sizes)
    {
        var conv = tf.layers.conv2d(
            x_emb,
            filters: num_filters,
            kernel_size: new int[] { filter_size, embedding_size },
            strides: new int[] { 1, 1 },
            padding: "VALID",
            activation: tf.nn.relu());
    }

    return graph;
}
190+
191+ private bool RunWithImportedGraph ( Session sess , Graph graph )
192+ {
193+ var stopwatch = Stopwatch . StartNew ( ) ;
80194
81195 sess . run ( tf . global_variables_initializer ( ) ) ;
82196 var saver = tf . train . Saver ( tf . global_variables ( ) ) ;
@@ -149,107 +263,12 @@ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
149263 return false ;
150264 }
151265
152- protected virtual bool RunWithBuiltGraph ( Session session , Graph graph )
153- {
154- Console . WriteLine ( "Building dataset..." ) ;
155- var ( x , y , alphabet_size ) = DataHelpers . build_char_dataset ( "train" , "word_cnn" , CHAR_MAX_LEN , DataLimit ) ;
156-
157- var ( train_x , valid_x , train_y , valid_y ) = train_test_split ( x , y , test_size : 0.15f ) ;
158-
159- ITextClassificationModel model = null ;
160- // todo train the model
161- return false ;
162- }
163-
164- // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
165- private ( NDArray , NDArray , NDArray , NDArray ) train_test_split ( NDArray x , NDArray y , float test_size = 0.3f )
166- {
167- Console . WriteLine ( "Splitting in Training and Testing data..." ) ;
168- int len = x . shape [ 0 ] ;
169- //int classes = y.Data<int>().Distinct().Count();
170- //int samples = len / classes;
171- int train_size = ( int ) Math . Round ( len * ( 1 - test_size ) ) ;
172- var train_x = x [ new Slice ( stop : train_size ) , new Slice ( ) ] ;
173- var valid_x = x [ new Slice ( start : train_size ) , new Slice ( ) ] ;
174- var train_y = y [ new Slice ( stop : train_size ) ] ;
175- var valid_y = y [ new Slice ( start : train_size ) ] ;
176- Console . WriteLine ( "\t DONE" ) ;
177- return ( train_x , valid_x , train_y , valid_y ) ;
178- }
179-
180- private static void FillWithShuffledLabels ( int [ ] [ ] x , int [ ] y , int [ ] [ ] shuffled_x , int [ ] shuffled_y , Random random , Dictionary < int , HashSet < int > > labels )
181- {
182- int i = 0 ;
183- var label_keys = labels . Keys . ToArray ( ) ;
184- while ( i < shuffled_x . Length )
185- {
186- var key = label_keys [ random . Next ( label_keys . Length ) ] ;
187- var set = labels [ key ] ;
188- var index = set . First ( ) ;
189- if ( set . Count == 0 )
190- {
191- labels . Remove ( key ) ; // remove the set as it is empty
192- label_keys = labels . Keys . ToArray ( ) ;
193- }
194- shuffled_x [ i ] = x [ index ] ;
195- shuffled_y [ i ] = y [ index ] ;
196- i ++ ;
197- }
198- }
199-
200- private IEnumerable < ( NDArray , NDArray , int ) > batch_iter ( NDArray inputs , NDArray outputs , int batch_size , int num_epochs )
201- {
202- var num_batches_per_epoch = ( len ( inputs ) - 1 ) / batch_size + 1 ;
203- var total_batches = num_batches_per_epoch * num_epochs ;
204- foreach ( var epoch in range ( num_epochs ) )
205- {
206- foreach ( var batch_num in range ( num_batches_per_epoch ) )
207- {
208- var start_index = batch_num * batch_size ;
209- var end_index = Math . Min ( ( batch_num + 1 ) * batch_size , len ( inputs ) ) ;
210- if ( end_index <= start_index )
211- break ;
212- yield return ( inputs [ new Slice ( start_index , end_index ) ] , outputs [ new Slice ( start_index , end_index ) ] , total_batches ) ;
213- }
214- }
215- }
216-
217- public void PrepareData ( )
218- {
219- // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
220- var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip" ;
221- Web . Download ( url , dataDir , "dbpedia_subset.zip" ) ;
222- Compress . UnZip ( Path . Combine ( dataDir , "dbpedia_subset.zip" ) , Path . Combine ( dataDir , "dbpedia_csv" ) ) ;
223-
224- if ( IsImportingGraph )
225- {
226- // download graph meta data
227- var meta_file = "word_cnn.meta" ;
228- var meta_path = Path . Combine ( "graph" , meta_file ) ;
229- if ( File . GetLastWriteTime ( meta_path ) < new DateTime ( 2019 , 05 , 11 ) )
230- {
231- // delete old cached file which contains errors
232- Console . WriteLine ( "Discarding cached file: " + meta_path ) ;
233- File . Delete ( meta_path ) ;
234- }
235- url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file ;
236- Web . Download ( url , "graph" , meta_file ) ;
237- }
238- }
239-
240- public Graph ImportGraph ( )
241- {
242- throw new NotImplementedException ( ) ;
243- }
244-
245- public Graph BuildGraph ( )
246- {
247- throw new NotImplementedException ( ) ;
248- }
249-
/// <summary>
/// Obtains the graph (imported or freshly built) and runs the session.
/// </summary>
public bool Train()
{
    Graph graph;
    if (IsImportingGraph)
        graph = ImportGraph();
    else
        graph = BuildGraph();

    // NOTE(review): both branches currently funnel into RunWithImportedGraph;
    // no dedicated run path exists for the freshly built graph yet.
    return with(tf.Session(graph), sess => RunWithImportedGraph(sess, graph));
}
254273
255274 public bool Predict ( )
0 commit comments