Python chainer.serializers module: save_npz() example source code

The following 26 code examples, extracted from open-source Python projects, illustrate how to use chainer.serializers.save_npz().

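For reference, a minimal round trip looks like this: save_npz() writes a link's parameters to a NumPy .npz file, and load_npz() restores them into an architecturally identical link. This is a sketch using a hypothetical single-layer model:

import chainer.links as L
from chainer import serializers

model = L.Linear(784, 10)                    # hypothetical model
serializers.save_npz('model.npz', model)     # write parameters to NPZ

restored = L.Linear(784, 10)                 # same architecture
serializers.load_npz('model.npz', restored)  # fill in the saved parameters
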
Project: ROCStory_skipthought_baseline    Author: soskek
def save(model, optimizer, save_name, args):
    serializers.save_npz(save_name + "model", copy.deepcopy(model).to_cpu())
    serializers.save_npz(save_name + "optimizer", optimizer)
    print('save', save_name)
Project: LSTMVAE    Author: ashwatthaman
def train(args, encdec, model_name_base="./{}/model/cvaehidden_kl_{}_{}_l{}.npz"):
    encdec.loadModel(model_name_base,args)
    if args.gpu >= 0:
        import cupy as cp
        global xp
        xp = cp
        encdec.to_gpu()

    optimizer = optimizers.Adam()
    optimizer.setup(encdec)
    for e_i in range(encdec.epoch_now, args.epoch):
        encdec.setEpochNow(e_i)
        loss_sum = 0
        for tupl in encdec.getBatchGen(args):
            loss = encdec(tupl)
            loss_sum += loss.data

            encdec.cleargrads()
            loss.backward()
            optimizer.update()
        print("epoch{}:loss_sum:{}".format(e_i, loss_sum))
        model_name = model_name_base.format(args.dataname, args.dataname, e_i, args.n_latent)
        serializers.save_npz(model_name, encdec)
Project: trainer    Author: nutszebra
def save_optimizer(self, optimizer, path=''):
        """Save optimizer model

        Example:

        ::

            path = './test.optimizer'
            self.save_optimizer(optimizer, path)

        Args:
            optimizer (chainer.optimizers): optimizer
            path (str): path

        Returns:
            bool: True if saving successful
        """

        # default path based on the save count
        if path == '':
            path = str(self.nz_save_optimizer_epoch) + '.optimizer'
        # increment self.nz_save_optimizer_epoch
        self.nz_save_optimizer_epoch += 1
        serializers.save_npz(path, optimizer)
        return True
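Saving the optimizer alongside the model is what makes exact resumption possible: the NPZ file captures internal optimizer state such as Adam's moment estimates. A minimal resume sketch (hypothetical model and filenames):

from chainer import optimizers, serializers
import chainer.links as L

model = L.Linear(10, 2)                       # hypothetical model
optimizer = optimizers.Adam()
optimizer.setup(model)

serializers.save_npz('ckpt.model', model)     # checkpoint both parts
serializers.save_npz('ckpt.state', optimizer)

# later: rebuild the same objects, then restore both in place
serializers.load_npz('ckpt.model', model)
serializers.load_npz('ckpt.state', optimizer)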
Project: chainer-cf-nade    Author: dsanno
def progress_func(epoch, loss, accuracy, valid_loss, valid_accuracy, test_loss, test_accuracy):
        print('epoch: {} done'.format(epoch))
        print('train mean loss={}, accuracy={}'.format(loss, accuracy))
        if valid_loss is not None and valid_accuracy is not None:
            print('valid mean loss={}, accuracy={}'.format(valid_loss, valid_accuracy))
        if test_loss is not None and test_accuracy is not None:
            print('test mean loss={}, accuracy={}'.format(test_loss, test_accuracy))
        if valid_accuracy < progress_state['valid_accuracy']:
            serializers.save_npz(args.output, net)
            progress_state['valid_accuracy'] = valid_accuracy
            progress_state['test_accuracy'] = test_accuracy
        if epoch % args.save_iter == 0:
            base, ext = os.path.splitext(args.output)
            serializers.save_npz('{0}_{1:04d}{2}'.format(base, epoch, ext), net)
        if args.lr_decay_iter > 0 and epoch % args.lr_decay_iter == 0:
            optimizer.alpha *= args.lr_decay_ratio
Project: chainer_pong    Author: icoxfog417
def save(self, index=0):
        fname = "pong.model" if index == 0 else "pong_{0}.model".format(index)
        path = os.path.join(self.model_path, fname)
        serializers.save_npz(path, self.q)
Project: chainer-object-detection    Author: dsanno
def main():
    args = parse_args()

    print("loading classifier model...")
    input_model = YOLOv2Classifier(args.input_class)
    serializers.load_npz(args.input_path, input_model)

    model = YOLOv2(args.output_class, args.box)
    copy_conv_layer(input_model, model, partial_layer)
    copy_bias_layer(input_model, model, partial_layer)
    copy_bn_layer(input_model, model, partial_layer)

    print("saving model to %s" % (args.output_path))
    serializers.save_npz(args.output_path, model)
Project: fcn    Author: wkentaro
def caffe_to_chainermodel(model, caffe_prototxt, caffemodel_path,
                          chainermodel_path):
    os.chdir(osp.dirname(caffe_prototxt))
    net = caffe.Net(caffe_prototxt, caffemodel_path, caffe.TEST)

    for name, param in net.params.items():
        try:
            layer = getattr(model, name)
        except AttributeError:
            print('Skipping caffe layer: %s' % name)
            continue

        has_bias = True
        if len(param) == 1:
            has_bias = False

        print('{0}:'.format(name))
        # weight
        print('  - W: %s %s' % (param[0].data.shape, layer.W.data.shape))
        assert param[0].data.shape == layer.W.data.shape
        layer.W.data = param[0].data
        # bias
        if has_bias:
            print('  - b: %s %s' % (param[1].data.shape, layer.b.data.shape))
            assert param[1].data.shape == layer.b.data.shape
            layer.b.data = param[1].data
    S.save_npz(chainermodel_path, model)
Project: chainerrl    Author: chainer
def save(self, dirname):
        """Save internal states."""
        makedirs(dirname, exist_ok=True)
        for attr in self.saved_attributes:
            assert hasattr(self, attr)
            attr_value = getattr(self, attr)
            if isinstance(attr_value, AttributeSavingMixin):
                assert attr_value is not self, "Avoid an infinite loop"
                attr_value.save(os.path.join(dirname, attr))
            else:
                serializers.save_npz(
                    os.path.join(dirname, '{}.npz'.format(attr)),
                    getattr(self, attr))
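The matching restore walks the same attribute list. A minimal sketch (not the chainerrl source itself; recursion into nested AttributeSavingMixin objects is omitted), assuming each attribute in saved_attributes was written as '<attr>.npz' by the save() above:

import os
from chainer import serializers

def load(agent, dirname):
    # Restore each saved attribute in place from its .npz file.
    for attr in agent.saved_attributes:
        serializers.load_npz(
            os.path.join(dirname, '{}.npz'.format(attr)),
            getattr(agent, attr))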
Project: trainer    Author: nutszebra
def save_model(self, path=''):
        """Save chainer model

        Example:

        ::

            path = './test.model'
            self.save_model(path)

        Args:
            path (str): path

        Returns:
            bool: True if saving successful
        """

        # if gpu_flag is True, switch the model to gpu mode at last
        gpu_flag = False
        # if gpu mode, switch the model to cpu mode temporarily
        if self.model_is_cpu_mode() is False:
            self.to_cpu()
            gpu_flag = True
        # default path based on the save count
        if path == '':
            path = str(self.nz_save_model_epoch) + '.model'
        # increment self.nz_save_model_epoch
        self.nz_save_model_epoch += 1
        serializers.save_npz(path, self)
        # if gpu_flag is True, switch the model to gpu mode at last
        if gpu_flag:
            self.to_gpu()
        return True
Project: MultimodalDL    Author: masataka46
def train_loop():
    # Trainer
    graph_generated = False
    while True:
        while data_q.empty():
            time.sleep(0.1)
        inp = data_q.get()
        if inp == 'end':  # quit
            res_q.put('end')
            break
        elif inp == 'train':  # restart training
            res_q.put('train')
            model.train = True
            continue
        elif inp == 'val':  # start validation
            res_q.put('val')
            serializers.save_npz(args.out, model)
            serializers.save_npz(args.outstate, optimizer)
            model.train = False
            continue

        volatile = 'off' if model.train else 'on'
        x = chainer.Variable(xp.asarray(inp[0]), volatile=volatile)
        t = chainer.Variable(xp.asarray(inp[1]), volatile=volatile)

        if model.train:
            optimizer.update(model, x, t)
            if not graph_generated:
                with open('graph.dot', 'w') as o:
                    o.write(computational_graph.build_computational_graph(
                        (model.loss,)).dump())
                print('generated graph', file=sys.stderr)
                graph_generated = True
        else:
            model(x, t)

        res_q.put((float(model.loss.data), float(model.accuracy.data)))
        del x, t

# Invoke threads
Project: chainer-deconv    Author: germanRos
def train_loop():
    # Trainer
    graph_generated = False
    while True:
        while data_q.empty():
            time.sleep(0.1)
        inp = data_q.get()
        if inp == 'end':  # quit
            res_q.put('end')
            break
        elif inp == 'train':  # restart training
            res_q.put('train')
            model.train = True
            continue
        elif inp == 'val':  # start validation
            res_q.put('val')
            serializers.save_npz(args.out, model)
            serializers.save_npz(args.outstate, optimizer)
            model.train = False
            continue

        volatile = 'off' if model.train else 'on'
        x = chainer.Variable(xp.asarray(inp[0]), volatile=volatile)
        t = chainer.Variable(xp.asarray(inp[1]), volatile=volatile)

        if model.train:
            optimizer.update(model, x, t)
            if not graph_generated:
                with open('graph.dot', 'w') as o:
                    o.write(computational_graph.build_computational_graph(
                        (model.loss,)).dump())
                print('generated graph', file=sys.stderr)
                graph_generated = True
        else:
            model(x, t)

        res_q.put((float(model.loss.data), float(model.accuracy.data)))
        del x, t

# Invoke threads
Project: chainer-deconv    Author: germanRos
def saveInfo(self, model, optimizer, epoch, outputFolder, saveEach):
        if epoch % saveEach == 0:
            if not os.path.exists(outputFolder):
                os.makedirs(outputFolder)
            bname = outputFolder + '/' + model.getName() + '_' + str(epoch)
            serializers.save_npz(bname + '.model', model)
            serializers.save_npz(bname + '.state', optimizer)
Project: chainer-deconv    Author: germanRos
def saveInfo(self, model, optimizer, smanager, epoch, outputFolder, saveEach):
        if epoch % saveEach == 0:
            if not os.path.exists(outputFolder):
                os.makedirs(outputFolder)
            bname = outputFolder + '/' + model.getName() + '_' + str(epoch)
            serializers.save_npz(bname + '.model', model)
            serializers.save_npz(bname + '.state', optimizer)
            smanager.save(bname + '.stats')
Project: GUINNESS    Author: HirokiNakahara
def on_epoch_done(epoch, n, o, loss, acc, valid_loss, valid_acc, test_loss, test_acc):
        error = 100 * (1 - acc)
        valid_error = 100 * (1 - valid_acc)
        test_error = 100 * (1 - test_acc)
        print('epoch {} done'.format(epoch))
        print('train loss: {} error: {}'.format(loss, error))
        print('valid loss: {} error: {}'.format(valid_loss, valid_error))
        print('test  loss: {} error: {}'.format(test_loss, test_error))
        if valid_error < state['best_valid_error']:
            serializers.save_npz('{}.model'.format(model_prefix), n)
            serializers.save_npz('{}.state'.format(model_prefix), o)
            state['best_valid_error'] = valid_error
            state['best_test_error'] = test_error
        if args.save_iter > 0 and (epoch + 1) % args.save_iter == 0:
            serializers.save_npz('{}_{}.model'.format(model_prefix, epoch + 1), n)
            serializers.save_npz('{}_{}.state'.format(model_prefix, epoch + 1), o)
        # prevent divergence when using identity mapping model
        if args.model == 'identity_mapping' and epoch < 9:
            o.lr = 0.01 + 0.01 * (epoch + 1)
#        if len(lr_decay_iter) == 1 and (epoch + 1) % lr_decay_iter[0] == 0 or epoch + 1 in lr_decay_iter:
        # Note: "lr_decay_iter" was originally a list holding a decay schedule,
        # but it was changed to a single integer for compatibility with Python 3.5.
        if (epoch + 1) % args.lr_decay_iter == 0 and epoch > 1:
            if hasattr(o, 'alpha'):
                o.alpha *= 0.1
            else:
                o.lr *= 0.1
        clock = time.clock()
        print('elapsed time: {}'.format(clock - state['clock']))
        state['clock'] = clock

        with open(log_file_path, 'a') as f:
            f.write('{},{},{},{},{},{},{}\n'.format(epoch + 1, loss, error, valid_loss, valid_error, test_loss, test_error))
Project: rnn-morpheme-analyzer    Author: mitaki28
def save_param(out_dir, epoch, storage):
    serializers.save_npz(
        str(out_dir/model_name(epoch)),
        storage.model
    )
    serializers.save_npz(
        str(out_dir/optimizer_name(epoch)),
        storage.optimizer
    )
Project: chainer-cifar    Author: dsanno
def on_epoch_done(epoch, n, o, loss, acc, valid_loss, valid_acc, test_loss, test_acc, test_time):
        error = 100 * (1 - acc)
        print('epoch {} done'.format(epoch))
        print('train loss: {} error: {}'.format(loss, error))
        if valid_loss is not None:
            valid_error = 100 * (1 - valid_acc)
            print('valid loss: {} error: {}'.format(valid_loss, valid_error))
        else:
            valid_error = None
        if test_loss is not None:
            test_error = 100 * (1 - test_acc)
            print('test  loss: {} error: {}'.format(test_loss, test_error))
            print('test time: {}s'.format(test_time))
        else:
            test_error = None
        if valid_loss is not None and valid_error < state['best_valid_error']:
            serializers.save_npz('{}.model'.format(model_prefix), n)
            serializers.save_npz('{}.state'.format(model_prefix), o)
            state['best_valid_error'] = valid_error
            state['best_test_error'] = test_error
        elif valid_loss is None:
            serializers.save_npz('{}.model'.format(model_prefix), n)
            serializers.save_npz('{}.state'.format(model_prefix), o)
            state['best_test_error'] = test_error
        if args.save_iter > 0 and (epoch + 1) % args.save_iter == 0:
            serializers.save_npz('{}_{}.model'.format(model_prefix, epoch + 1), n)
            serializers.save_npz('{}_{}.state'.format(model_prefix, epoch + 1), o)
        # prevent divergence when using identity mapping model
        if args.model == 'identity_mapping' and epoch < 9:
            o.lr = 0.01 + 0.01 * (epoch + 1)
        clock = time.clock()
        print('elapsed time: {}'.format(clock - state['clock']))
        state['clock'] = clock
        with open(log_file_path, 'a') as f:
            f.write('{},{},{},{},{},{},{}\n'.format(epoch + 1, loss, error, valid_loss, valid_error, test_loss, test_error))
Project: chainercv    Author: chainer
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('caffemodel')
    parser.add_argument('output')
    args = parser.parse_args()

    model = SSDCaffeFunction(args.caffemodel)
    serializers.save_npz(args.output, model)
Project: doubleDQN    Author: masataka46
def agent_message(self, inMessage):
        if inMessage.startswith("freeze learning"):
            self.policyFrozen = True
            return "message understood, policy frozen"

        if inMessage.startswith("unfreeze learning"):
            self.policyFrozen = False
            return "message understood, policy unfrozen"

        if inMessage.startswith("save model"):
            serializers.save_npz('resume.model', self.DDQN.model) # save current model
            np.savez('stored_D012.npz', D0=self.DDQN.D[0], D1=self.DDQN.D[1], D2=self.DDQN.D[2])
            np.savez('stored_D34.npz', D3=self.DDQN.D[3], D4=self.DDQN.D[4])
            return "message understood, model saved"
Project: der-network    Author: soskek
def save(model, optimizer, vocab, save_name, args):
    serializers.save_npz(save_name+"model", copy.deepcopy(model).to_cpu())
    serializers.save_npz(save_name+"optimizer", optimizer)
    json.dump(vocab, open(save_name+"vocab.json", "w"))
    print('save', save_name)
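A matching restore sketch (hypothetical, mirroring the save above): rebuild the model and optimizer, load their parameters in place, and read the vocabulary back from JSON:

import json
from chainer import serializers

def load(model, optimizer, save_name):
    serializers.load_npz(save_name + "model", model)
    serializers.load_npz(save_name + "optimizer", optimizer)
    with open(save_name + "vocab.json") as f:
        vocab = json.load(f)
    return vocab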
Project: mlimages    Author: icoxfog417
def train(epoch=10, batch_size=32, gpu=False):
    if gpu:
        cuda.check_cuda_available()
    xp = cuda.cupy if gpu else np

    td = TrainingData(LABEL_FILE, img_root=IMAGES_ROOT, image_property=IMAGE_PROP)

    # make mean image
    if not os.path.isfile(MEAN_IMAGE_FILE):
        print("make mean image...")
        td.make_mean_image(MEAN_IMAGE_FILE)
    else:
        td.mean_image_file = MEAN_IMAGE_FILE

    # train model
    label_def = LabelingMachine.read_label_def(LABEL_DEF_FILE)
    model = alex.Alex(len(label_def))
    optimizer = optimizers.MomentumSGD(lr=0.01, momentum=0.9)
    optimizer.setup(model)

    print("Now our model is {0} classification task.".format(len(label_def)))
    print("begin training the model. epoch:{0} batch size:{1}.".format(epoch, batch_size))

    if gpu:
        model.to_gpu()

    for i in range(epoch):
        print("epoch {0}/{1}: (learning rate={2})".format(i + 1, epoch, optimizer.lr))
        td.shuffle(overwrite=True)

        for x_batch, y_batch in td.generate_batches(batch_size):
            x = chainer.Variable(xp.asarray(x_batch))
            t = chainer.Variable(xp.asarray(y_batch))

            optimizer.update(model, x, t)
            print("loss: {0}, accuracy: {1}".format(float(model.loss.data), float(model.accuracy.data)))

        serializers.save_npz(MODEL_FILE, model)
        optimizer.lr *= 0.97
Project: chainer_frmqn    Author: okdshin
def save(self, name):
        serializers.save_npz(name+".model", self.dqn.model)
        serializers.save_npz(name+".optimizer", self.dqn.optimizer)
Project: machine_learning_in_application    Author: icoxfog417
def save_model(self, model):
        if not os.path.exists(self.model_path):
            os.mkdir(self.model_path)
        timestamp = datetime.strftime(datetime.now(), "%Y%m%d%H%M%S")
        model_file = os.path.join(self.model_path, "./" + model.__class__.__name__.lower() + "_" + timestamp + ".model")
        serializers.save_npz(model_file, model)
Project: chainer-EWC    Author: okdshin
def train_tasks_continuosly(
        args, model, train, test, train2, test2, enable_ewc):
    # Train Task A or load trained model
    if os.path.exists("mlp_taskA.model") or args.skip_taskA:
        print("load taskA model")
        serializers.load_npz("./model50/mlp_taskA.model", model)
    else:
        print("train taskA")
        train_task(args, "train_task_a"+("_with_ewc" if enable_ewc else ""),
                   model, args.epoch, train,
                   {"TaskA": test}, args.batchsize)
        print("save the model")
        serializers.save_npz("mlp_taskA.model", model)

    if enable_ewc:
        print("enable EWC")
        model.compute_fisher(train)
        model.store_variables()

    # Train Task B
    print("train taskB")
    train_task(args, "train_task_ab"+("_with_ewc" if enable_ewc else ""),
               model, args.epoch, train2,
               {"TaskA": test, "TaskB": test2}, args.batchsize)
    print("save the model")
    serializers.save_npz(
            "mlp_taskAB"+("_with_ewc" if enable_ewc else "")+".model", model)
Project: chainer-stack-gan    Author: dsanno
def train(gen1, gen2, dis, optimizer_gen, optimizer_dis, images, epoch_num, output_path, lr_decay=10, save_epoch=1, batch_size=64, margin=20, out_image_dir=None, clip_rect=None):
    xp = gen1.xp
    out_image_row_num = 10
    out_image_col_num = 10
    z_out_image = xp.random.normal(0, 1, (out_image_row_num * out_image_col_num, latent_size)).astype(np.float32)
    z_out_image = z_out_image / (xp.linalg.norm(z_out_image, axis=1, keepdims=True) + 1e-12)
    x_batch = np.zeros((batch_size, 3, image_size, image_size), dtype=np.float32)
    iterator = chainer.iterators.SerialIterator(images, batch_size)
    sum_loss_gen = 0
    sum_loss_dis = 0
    num_loss = 0
    last_clock = time.clock()
    for batch_images in iterator:
        for j, image in enumerate(batch_images):
            with io.BytesIO(image) as b:
                pixels = Image.open(b).convert('RGB')
                if clip_rect is not None:
                    offset_left = np.random.randint(-4, 5)
                    offset_top = np.random.randint(-4, 5)
                    pixels = pixels.crop((clip_rect[0] + offset_left, clip_rect[1] + offset_top) + clip_rect[2:])
                pixels = np.asarray(pixels.resize((image_size, image_size)), dtype=np.float32)
                pixels = pixels.transpose((2, 0, 1))
                x_batch[j,...] = pixels / 127.5 - 1
        loss_gen, loss_dis = update(gen1, gen2, dis, optimizer_gen, optimizer_dis, x_batch, margin)
        sum_loss_gen += loss_gen
        sum_loss_dis += loss_dis
        num_loss += 1
        if iterator.is_new_epoch:
            epoch = iterator.epoch
            current_clock = time.clock()
            print('epoch {} done {}s elapsed'.format(epoch, current_clock - last_clock))
            print('gen loss: {}'.format(sum_loss_gen / num_loss))
            print('dis loss: {}'.format(sum_loss_dis / num_loss))
            last_clock = current_clock
            sum_loss_gen = 0
            sum_loss_dis = 0
            num_loss = 0
            if iterator.epoch % lr_decay == 0:
                optimizer_gen.alpha *= 0.5
                optimizer_dis.alpha *= 0.5
            if iterator.epoch % save_epoch == 0:
                if out_image_dir is not None:
                    image = np.zeros((out_image_row_num * out_image_col_num, 3, image_size, image_size), dtype=np.uint8)
                    for i in six.moves.range(out_image_row_num):
                        with chainer.no_backprop_mode():
                            begin_index = i * out_image_col_num
                            end_index = (i + 1) * out_image_col_num
                            sub_image = gen2(gen1(z_out_image[begin_index:end_index], train=False), train=False).data
                            sub_image = ((cuda.to_cpu(sub_image) + 1) * 127.5)
                            image[begin_index:end_index, ...] = sub_image.clip(0, 255).astype(np.uint8)
                    image = image.reshape(out_image_row_num, out_image_col_num, 3, image_size, image_size)
                    image = image.transpose((0, 3, 1, 4, 2))
                    image = image.reshape((out_image_row_num * image_size, out_image_col_num * image_size, 3))
                    Image.fromarray(image).save(os.path.join(out_image_dir, '{0:04d}.png'.format(epoch)))
                serializers.save_npz('{0}_{1:03d}.gen.model'.format(output_path, epoch), gen2)
                serializers.save_npz('{0}_{1:03d}.gen.state'.format(output_path, epoch), optimizer_gen)
                serializers.save_npz('{0}_{1:03d}.dis.model'.format(output_path, epoch), dis)
                serializers.save_npz('{0}_{1:03d}.dis.state'.format(output_path, epoch), optimizer_dis)
            if iterator.epoch >= epoch_num:
                break
Project: rnn-morpheme-analyzer    Author: mitaki28
def run_training(args):
    out_dir = pathlib.Path(args.directory)
    sentences = dataset.load(args.source)

    if args.epoch is not None:
        start = args.epoch + 1
        storage = load(out_dir, args.epoch)
        sentences = itertools.islice(sentences, start, None)
    else:
        start = 0
        storage = init(args)        
        if (out_dir/meta_name).exists():
            if input('Overwrite? [y/N]: ').strip().lower() != 'y':
                exit(1)
        with (out_dir/meta_name).open('wb') as f:
            np.save(f, [storage])

    batchsize = 5000
    for i, sentence in enumerate(sentences, start):
        if i % batchsize == 0:
            print()
            serializers.save_npz(
                str(out_dir/model_name(i)),
                storage.model
            )
            serializers.save_npz(
                str(out_dir/optimizer_name(i)),
                storage.optimizer
            )
        else:
            print(
                util.progress(
                    'batch {}'.format(i // batchsize),
                    (i % batchsize) / batchsize, 100),
                end=''
            )
        train(storage.model,
              storage.optimizer,
              generate_data(sentence),
              generate_label(sentence),
              generate_attr(
                  sentence,
                  storage.mappings
              )
        )
Project: chainer_sklearn    Author: corochann
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    parser.add_argument('--example', '-ex', type=int, default=3,
                        help='Example mode')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    model = SklearnWrapperClassifier(MLP(args.unit, 10), device=args.gpu)

    if args.example == 1:
        print("Example 1. fit with x, y numpy array (same with sklearn's fit)")
        x, y = concat_examples(train)
        model.fit(x, y)
    elif args.example == 2:
        print("Example 2. Train with Chainer's dataset")
        # `train` is TupleDataset in this example
        # Even this one line work! (but no validation)
        model.fit(train)
    else:
        print("Example 3. Train with configuration")
        model.fit(
            train,
            test=test,
            batchsize=args.batchsize,
            #iterator_class=chainer.iterators.SerialIterator,
            optimizer=chainer.optimizers.Adam(),
            epoch=args.epoch,
            out=args.out,
            snapshot_frequency=1,
            #dump_graph=False
            #log_report=True,
            plot_report=False,
            #print_report=True,
            progress_report=False,
            resume=args.resume
        )

    # Save trained model
    serializers.save_npz('{}/mlp.model'.format(args.out), model)
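
As an aside, when training with a chainer.training.Trainer (as the configuration example above does internally), per-epoch NPZ checkpoints can also be produced by an extension rather than an explicit save_npz() call. A sketch, assuming a trainer object is already set up:

from chainer.training import extensions

# Writes the model's parameters with the NPZ serializer once per epoch.
trainer.extend(
    extensions.snapshot_object(model, 'mlp_epoch_{.updater.epoch}.npz'),
    trigger=(1, 'epoch'))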