diff --git a/ReadData_np.py b/ReadData_np.py index 2cf10dc..1eac0e5 100644 --- a/ReadData_np.py +++ b/ReadData_np.py @@ -6,6 +6,7 @@ import numpy as np import os import json +import random test_images_name = [] test_sketchs_name = [] @@ -108,11 +109,9 @@ def ReadData(sess, batch_size = 128, is_train = True): ineg = [] if is_train: - i = 0 while True: - sk_i = i - for j in range(len(train_triplets[i])): - sk_i = i + sk_i = random.randint(0, len(shoes_images)-1) + for j in range(len(train_triplets[sk_i])): im_pos_i = train_triplets[sk_i][j][0] im_neg_i = train_triplets[sk_i][j][1] @@ -128,10 +127,7 @@ def ReadData(sess, batch_size = 128, is_train = True): s = [] ipos = [] ineg = [] - - i += 1 - if(i >= len(train_triplets)): - i = 0 + else: i = 0 while True: @@ -177,6 +173,6 @@ def ReadData(sess, batch_size = 128, is_train = True): plt.show() ''' - a = ReadData(sess, 5, False) + a = ReadData(sess, 5, True) next(a) next(a) \ No newline at end of file