Python utils 模块,get_minibatches_idx() 实例源码

我们从Python开源项目中,提取了以下5个代码示例,用于说明如何使用utils.get_minibatches_idx()

项目:RL4Data    作者:fyabc    | 项目源码 | 文件源码
def pre_process_config(model, train_size, valid_size, test_size):
    """Collect the evaluation minibatch indices and loop settings for a run.

    The validation and test minibatch indices are both built with
    ``model.validate_batch_size``.  The ``valid_freq`` and ``save_freq``
    entries of ``ParamConfig`` fall back to
    ``train_size // model.train_batch_size`` when they are set to ``-1``.

    Returns:
        The tuple ``(kf_valid, kf_test, valid_freq, save_freq,
        display_freq, save_to, patience)``.
    """
    eval_batch_size = model.validate_batch_size
    kf_valid = get_minibatches_idx(valid_size, eval_batch_size)
    kf_test = get_minibatches_idx(test_size, eval_batch_size)

    def _frequency(key):
        # -1 is the sentinel for "derive from the training set size"; the
        # division is only evaluated when the sentinel is actually present.
        configured = ParamConfig[key]
        if configured != -1:
            return configured
        return train_size // model.train_batch_size

    return (
        kf_valid,
        kf_test,
        _frequency('valid_freq'),
        _frequency('save_freq'),
        ParamConfig['display_freq'],
        ParamConfig['save_to'],
        ParamConfig['patience'],
    )
项目:variational-dropout    作者:cjratcliff    | 项目源码 | 文件源码
def fit(self,X,y,sess):
    """Train for a fixed 20 epochs, printing accuracy and timing per epoch.

    ``X`` and ``y`` are split once into training and validation parts
    (33% held out, ``random_state=42``).  Every epoch draws freshly
    shuffled minibatches, runs ``self.train_step`` on each with
    ``self.deterministic`` set to False, and then scores ``self.predict``
    on the validation part.

    NOTE(review): the validation labels are reduced with ``argmax`` over
    axis 1, so ``y`` is presumably one-hot encoded -- confirm with callers.
    ``batch_size`` is read from the enclosing module scope.
    """
    max_epochs = 20

    # Hold out a fixed validation subset
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.33, random_state=42)

    for epoch in range(max_epochs):
        start = time.time()
        train_indices = get_minibatches_idx(len(X_train), batch_size, shuffle=True)
        print("\nEpoch %d" % (epoch+1))

        train_accs = []
        for batch_ids in train_indices:
            feed_dict = {
                self.x: [X_train[i] for i in batch_ids],
                self.y: [y_train[i] for i in batch_ids],
                self.deterministic: False,
            }
            _, acc = sess.run([self.train_step, self.accuracy], feed_dict)
            train_accs.append(acc)
        print("Training accuracy: %.3f" % np.mean(train_accs))

        val_pred = self.predict(X_val, sess)
        val_true = np.argmax(y_val, axis=1)
        val_acc = np.mean(np.equal(val_pred, val_true))
        print("Val accuracy: %.3f" % val_acc)
        print("Time taken: %.3fs" % (time.time() - start))
项目:variational-dropout    作者:cjratcliff    | 项目源码 | 文件源码
def predict(self,X,sess):
    """Return the predicted class index for every element of ``X``.

    ``X`` is walked in unshuffled minibatches; each one is evaluated
    through ``self.pred`` with ``self.deterministic`` set to True.  The
    per-batch outputs are concatenated along axis 0, reduced with
    ``argmax`` over axis 1 and flattened to one dimension.

    ``batch_size`` is read from the enclosing module scope.
    """
    batches = get_minibatches_idx(len(X), batch_size, shuffle=False)
    outputs = [
        sess.run(self.pred, {self.x: [X[j] for j in idx],
                             self.deterministic: True})
        for idx in batches
    ]
    scores = np.concatenate(outputs, axis=0)
    labels = np.argmax(scores, axis=1)
    return np.reshape(labels, (-1))
项目:Learning-sentence-representation-with-guidance-of-human-attention    作者:wangshaonan    | 项目源码 | 文件源码
def train(model, data, words, params):
    """Run the minibatch training loop for ``params.epochs`` epochs.

    For every minibatch the embeddings of both items of each pair are
    populated, ``getpairs`` builds the inputs, and ``model.train_function``
    is called; its result is kept as ``cost``.  When
    ``utils.checkIfQuarter(uidx, len(kf))`` is true, and again at the end of
    every epoch, the model is optionally checkpointed (``params.save``) to
    ``params.outfile + str(counter) + '.pickle'`` and optionally evaluated
    (``params.evaluate``).  Ctrl-C stops training early; the elapsed time is
    printed in either case.

    Written to run unchanged on Python 2 and Python 3: every ``print`` takes
    a single pre-formatted string, so the emitted text matches the original
    print statements.

    Args:
        model: object exposing ``train_function`` (also handed to
            ``getpairs``, ``utils.saveParams`` and ``evaluate_all``).
        data: indexable collection of pairs; ``data[t][0]`` and
            ``data[t][1]`` must provide ``populate_embeddings``,
            ``unpopulate_embeddings`` and a ``representation`` attribute.
        words: passed to ``populate_embeddings`` and ``evaluate_all``.
        params: settings object with ``epochs``, ``batchsize``, ``save``,
            ``evaluate`` and ``outfile``.
    """
    start_time = time.time()

    counter = 0
    # Bound up front so the per-epoch report cannot raise NameError when an
    # epoch yields no minibatches (e.g. ``data`` is empty).
    cost = None
    try:
        for eidx in range(params.epochs):

            # Materialised so that len(kf) below also works if the helper
            # hands back an iterator rather than a list.
            kf = list(utils.get_minibatches_idx(len(data), params.batchsize, shuffle=True))
            uidx = 0
            for _, train_index in kf:

                uidx += 1

                batch = [data[t] for t in train_index]
                for i in batch:
                    i[0].populate_embeddings(words)
                    i[1].populate_embeddings(words)

                (g1x, g1mask, g2x, g2mask, p1x, p1mask, p2x, p2mask) = getpairs(model, batch, params)

                cost = model.train_function(g1x, g2x, p1x, p2x, g1mask, g2mask, p1mask, p2mask)

                # Diagnostic only: training continues after a bad cost.
                if np.isnan(cost) or np.isinf(cost):
                    print('NaN detected')

                if utils.checkIfQuarter(uidx, len(kf)):
                    if params.save:
                        counter += 1
                        utils.saveParams(model, params.outfile + str(counter) + '.pickle')
                    if params.evaluate:
                        evaluate_all(model, words)
                        sys.stdout.flush()

                # undo batch to save RAM
                for i in batch:
                    i[0].representation = None
                    i[1].representation = None
                    i[0].unpopulate_embeddings()
                    i[1].unpopulate_embeddings()

            if params.save:
                counter += 1
                utils.saveParams(model, params.outfile + str(counter) + '.pickle')

            if params.evaluate:
                evaluate_all(model, words)

            # Same text as the original comma-separated print statement.
            print('Epoch  %s Cost  %s' % (eidx + 1, cost))

    except KeyboardInterrupt:
        print("Training interupted")

    end_time = time.time()
    print("total time: %s" % (end_time - start_time))
项目:iclr2016    作者:jwieting    | 项目源码 | 文件源码
def train(model, data, words, params):
    """Train ``model`` on ``data`` in shuffled minibatches.

    Each of the ``params.epochs`` epochs asks
    ``utils.get_minibatches_idx`` for a shuffled set of minibatches.  Per
    minibatch the function populates the embeddings of both elements of
    every pair, builds the inputs with ``getpairs``, and stores the result
    of ``model.train_function`` as ``cost``.  Whenever
    ``utils.checkIfQuarter(uidx, len(kf))`` holds, and once more after each
    epoch, it saves the model to
    ``params.outfile + str(counter) + '.pickle'`` if ``params.save`` is set
    and calls ``evaluate_all`` if ``params.evaluate`` is set.  A
    ``KeyboardInterrupt`` ends training early; total wall-clock time is
    printed afterwards either way.

    The body runs on both Python 2 and Python 3; each ``print`` receives
    one pre-formatted string, which reproduces the text of the original
    print statements.

    Args:
        model: provides ``train_function``; also forwarded to ``getpairs``,
            ``utils.saveParams`` and ``evaluate_all``.
        data: indexable collection of pairs whose two elements expose
            ``populate_embeddings``, ``unpopulate_embeddings`` and a
            ``representation`` attribute.
        words: forwarded to ``populate_embeddings`` and ``evaluate_all``.
        params: configuration with ``epochs``, ``batchsize``, ``save``,
            ``evaluate`` and ``outfile``.
    """
    start_time = time.time()

    counter = 0
    # Give ``cost`` a value before the loops: with no minibatches in an
    # epoch the end-of-epoch report would otherwise raise NameError.
    cost = None
    try:
        for eidx in range(params.epochs):

            # Turn the result into a list so len(kf) is valid even when the
            # helper returns an iterator instead of a list.
            kf = list(utils.get_minibatches_idx(len(data), params.batchsize, shuffle=True))
            uidx = 0
            for _, train_index in kf:

                uidx += 1

                batch = [data[t] for t in train_index]
                for i in batch:
                    i[0].populate_embeddings(words)
                    i[1].populate_embeddings(words)

                (g1x, g1mask, g2x, g2mask, p1x, p1mask, p2x, p2mask) = getpairs(model, batch, params)

                cost = model.train_function(g1x, g2x, p1x, p2x, g1mask, g2mask, p1mask, p2mask)

                # Only reported; the loop carries on with the next batch.
                if np.isnan(cost) or np.isinf(cost):
                    print('NaN detected')

                if utils.checkIfQuarter(uidx, len(kf)):
                    if params.save:
                        counter += 1
                        utils.saveParams(model, params.outfile + str(counter) + '.pickle')
                    if params.evaluate:
                        evaluate_all(model, words)
                        sys.stdout.flush()

                # undo batch to save RAM
                for i in batch:
                    i[0].representation = None
                    i[1].representation = None
                    i[0].unpopulate_embeddings()
                    i[1].unpopulate_embeddings()

            if params.save:
                counter += 1
                utils.saveParams(model, params.outfile + str(counter) + '.pickle')

            if params.evaluate:
                evaluate_all(model, words)

            # Matches the output of the original comma-separated print.
            print('Epoch  %s Cost  %s' % (eidx + 1, cost))

    except KeyboardInterrupt:
        print("Training interupted")

    end_time = time.time()
    print("total time: %s" % (end_time - start_time))