我们从Python开源项目中，提取了以下5个代码示例，用于说明如何使用utils.get_minibatches_idx()。
def pre_process_config(model, train_size, valid_size, test_size):
    """Collect the evaluation batches and loop settings for a training run.

    Validation and test minibatch indices are built with the model's
    ``validate_batch_size``. ``valid_freq`` and ``save_freq`` are read from
    ``ParamConfig``; a stored value of ``-1`` is replaced by
    ``train_size // model.train_batch_size``. ``display_freq``, ``save_to``
    and ``patience`` are taken from ``ParamConfig`` unchanged.

    Returns:
        The tuple ``(kf_valid, kf_test, valid_freq, save_freq, display_freq,
        save_to, patience)``.
    """
    valid_batches = get_minibatches_idx(valid_size, model.validate_batch_size)
    test_batches = get_minibatches_idx(test_size, model.validate_batch_size)

    # -1 is the sentinel meaning "derive the frequency from the dataset size".
    frequencies = []
    for key in ('valid_freq', 'save_freq'):
        frequency = ParamConfig[key]
        if frequency == -1:
            frequency = train_size // model.train_batch_size
        frequencies.append(frequency)
    valid_every, save_every = frequencies

    return (
        valid_batches,
        test_batches,
        valid_every,
        save_every,
        ParamConfig['display_freq'],
        ParamConfig['save_to'],
        ParamConfig['patience'],
    )
def fit(self, X, y, sess, max_epochs=20):
    """Train on ``(X, y)`` and print per-epoch training/validation accuracy.

    The inputs are split once into training and validation parts
    (``test_size=0.33``, ``random_state=42``). Every epoch draws freshly
    shuffled minibatch indices, runs ``self.train_step`` on each batch, and
    then scores the validation part through ``self.predict``.

    Args:
        X: Indexable collection of samples.
        y: Labels aligned with ``X``. The validation part is reduced with
            ``np.argmax(..., axis=1)``, so rows are presumably one-hot or
            per-class scores -- TODO confirm against callers.
        sess: Object exposing ``run(fetches, feed_dict)``; presumably a
            TensorFlow session -- confirm.
        max_epochs: Number of passes over the training part. Defaults to 20,
            the value that used to be hard-coded.

    Returns:
        None. Progress is reported via ``print`` only.
    """
    # Split into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.33, random_state=42)

    # The validation labels never change between epochs, so decode them once,
    # and under their own name instead of rebinding the ``y`` argument.
    val_labels = np.argmax(y_val, axis=1)

    for epoch in range(max_epochs):
        start = time.time()
        # ``batch_size`` is not defined in this method; it is expected to be
        # a module-level name.
        train_indices = get_minibatches_idx(len(X_train), batch_size, shuffle=True)
        print("\nEpoch %d" % (epoch + 1))

        train_accs = []
        for batch_indices in train_indices:
            feed_dict = {
                self.x: [X_train[i] for i in batch_indices],
                self.y: [y_train[i] for i in batch_indices],
                self.deterministic: False,
            }
            _, acc = sess.run([self.train_step, self.accuracy], feed_dict)
            train_accs.append(acc)
        # Unweighted mean of per-batch accuracies (a smaller final batch
        # counts as much as a full one).
        print("Training accuracy: %.3f" % np.mean(train_accs))

        val_pred = self.predict(X_val, sess)
        val_acc = np.mean(np.equal(val_pred, val_labels))
        print("Val accuracy: %.3f" % val_acc)
        # The reported time covers both the training and the validation pass.
        print("Time taken: %.3fs" % (time.time() - start))
def predict(self, X, sess):
    """Return the arg-max class index for every sample in ``X``.

    Samples are fed through ``self.pred`` in unshuffled minibatches with
    ``self.deterministic`` set to True; the per-batch outputs are
    concatenated along axis 0, reduced with ``argmax`` over axis 1 and
    flattened to one dimension.

    Args:
        X: Indexable collection of samples.
        sess: Object exposing ``run(fetches, feed_dict)``; presumably a
            TensorFlow session -- confirm.

    Returns:
        1-D NumPy array of predicted class indices, in input order.
    """
    # ``batch_size`` is not defined in this method; it is expected to be a
    # module-level name.
    batches = get_minibatches_idx(len(X), batch_size, shuffle=False)

    batch_outputs = [
        sess.run(
            self.pred,
            {self.x: [X[idx] for idx in batch_indices], self.deterministic: True},
        )
        for batch_indices in batches
    ]

    scores = np.concatenate(batch_outputs, axis=0)
    labels = np.argmax(scores, axis=1)
    return np.reshape(labels, -1)
def train(model, data, words, params):
    """Run ``params.epochs`` training epochs over ``data`` in shuffled minibatches.

    For each minibatch the two elements of every pair get their embeddings
    populated from ``words``, ``getpairs`` builds the inputs, and
    ``model.train_function`` is called to obtain the cost. Depending on
    ``params.save`` / ``params.evaluate`` the model is saved and/or evaluated
    both at the points selected by ``utils.checkIfQuarter`` and at the end of
    every epoch. A ``KeyboardInterrupt`` stops training without propagating;
    the total wall-clock time is printed either way.

    Args:
        model: Object providing ``train_function``; also passed to
            ``getpairs``, ``utils.saveParams`` and ``evaluate_all``.
        data: Sequence of pairs; ``item[0]`` and ``item[1]`` must provide
            ``populate_embeddings``, ``unpopulate_embeddings`` and a
            ``representation`` attribute.
        words: Passed to ``populate_embeddings`` and ``evaluate_all``.
        params: Settings object; reads ``epochs``, ``batchsize``, ``save``,
            ``evaluate`` and ``outfile``.

    Returns:
        None.
    """
    start_time = time.time()
    # Running index appended to ``params.outfile`` for each saved snapshot.
    counter = 0
    try:
        for eidx in xrange(params.epochs):
            # ``kf`` is iterated as 2-item entries and passed to ``len()``;
            # only the second item (the sample indices) is used.
            kf = utils.get_minibatches_idx(len(data), params.batchsize, shuffle=True)
            # Number of minibatches processed so far in this epoch.
            uidx = 0
            for _, train_index in kf:
                uidx += 1
                batch = [data[t] for t in train_index]
                for i in batch:
                    i[0].populate_embeddings(words)
                    i[1].populate_embeddings(words)
                (g1x, g1mask, g2x, g2mask, p1x, p1mask, p2x, p2mask) = getpairs(model, batch, params)
                # Note the argument order: all x inputs first, then all masks.
                cost = model.train_function(g1x, g2x, p1x, p2x, g1mask, g2mask, p1mask, p2mask)
                # NOTE(review): a NaN/inf cost is only reported; training
                # continues with the next batch.
                if np.isnan(cost) or np.isinf(cost):
                    print 'NaN detected'
                # Intra-epoch checkpoint; presumably fires at quarter points
                # of the epoch -- confirm in utils.checkIfQuarter.
                if (utils.checkIfQuarter(uidx, len(kf))):
                    if (params.save):
                        counter += 1
                        utils.saveParams(model, params.outfile + str(counter) + '.pickle')
                    if (params.evaluate):
                        evaluate_all(model, words)
                        sys.stdout.flush()
                # Undo the per-batch population to save RAM.
                for i in batch:
                    i[0].representation = None
                    i[1].representation = None
                    i[0].unpopulate_embeddings()
                    i[1].unpopulate_embeddings()
                #print 'Epoch ', (eidx+1), 'Update ', (uidx+1), 'Cost ', cost
            # End-of-epoch snapshot and evaluation.
            if (params.save):
                counter += 1
                utils.saveParams(model, params.outfile + str(counter) + '.pickle')
            if (params.evaluate):
                evaluate_all(model, words)
            # NOTE(review): ``cost`` is assigned only inside the batch loop,
            # so this raises NameError if the first epoch yields no batches.
            print 'Epoch ', (eidx + 1), 'Cost ', cost
    except KeyboardInterrupt:
        # Ctrl-C ends training early but still falls through to the timing
        # report below.
        print "Training interupted"
    end_time = time.time()
    print "total time:", (end_time - start_time)