我们从 Python 开源项目中提取了以下 4 个代码示例，用于说明如何使用 chainer.serializers 模块（示例中用到了 save_hdf5、load_hdf5、save_npz 和 load_npz）。
def _build_model(self, config, src_vocab, trg_vocab):
    """Build the model named in ``config['Model']`` and attach vocabulary info.

    Every entry of the ``Model`` section except ``name`` is passed to the
    model constructor (looked up by name on ``models``) as a keyword argument,
    after numeric-looking strings are converted to ``int`` or ``float``.  If
    ``<self.save_dir>/model.hdf`` exists, its weights are loaded into the new
    model with ``chainer.serializers.load_hdf5``.

    Args:
        config: Mapping with ``'Model'`` and ``'Training'`` sections
            (presumably a ``configparser.ConfigParser`` -- confirm).
        src_vocab: Source vocabulary; its ``stoi`` is called to look up the
            ids of ``'<s>'`` and ``'</s>'``.
        trg_vocab: Target vocabulary; used the same way, and additionally
            attached to the model as ``m.vocab`` when ``m.byte`` is truthy.

    Returns:
        The constructed (and, if a checkpoint exists, weight-restored) model.
    """
    def convert(val):
        """Coerce a config value to int, then float; otherwise return it as-is."""
        # Non-string values (e.g. from a dict-based config) have no
        # ``isdigit``; pass them through untouched instead of crashing.
        if not isinstance(val, str):
            return val
        # Trying int before float keeps signed integers such as "-1" as ints
        # (``str.isdigit`` rejects the sign), and avoids crashing on strings
        # like "²" that ``isdigit`` accepts but ``int`` rejects.  Only
        # ValueError is caught so KeyboardInterrupt etc. still propagate.
        for cast in (int, float):
            try:
                return cast(val)
            except ValueError:
                continue
        return val

    model_config = config['Model']
    kwargs = {k: convert(v) for k, v in model_config.items() if k != 'name'}
    m = getattr(models, model_config['name'])(**kwargs)

    # Resume from previously saved weights when a checkpoint file is present.
    model_path = os.path.join(self.save_dir, 'model.hdf')
    if os.path.exists(model_path):
        chainer.serializers.load_hdf5(model_path, m)

    # NOTE(review): ``stoi`` is *called* here, so these vocab objects are
    # assumed to expose it as a callable rather than a plain dict -- confirm.
    xstoi = src_vocab.stoi
    ystoi = trg_vocab.stoi
    xbos = xstoi('<s>')
    xeos = xstoi('</s>')
    ybos = ystoi('<s>')
    yeos = ystoi('</s>')
    m.set_symbols(xbos, xeos, ybos, yeos)

    m.name = model_config['name']
    m.byte = self._load_binary_config(config['Training'], 'byte')
    m.reverse_output = self._load_binary_config(
        config['Training'], 'reverse_output')
    # In byte mode the model keeps a reference to the target vocabulary.
    if m.byte:
        m.vocab = trg_vocab
    return m
def save(self):
    """Persist the model weights and both vocabularies into ``self.save_dir``.

    Writes two files:
      * ``model.hdf`` -- a CPU copy of ``self.model`` serialized with
        ``chainer.serializers.save_hdf5``;
      * ``vocab.pkl`` -- the tuple ``(self.src_vcb, self.trg_vcb)`` pickled.

    The live ``self.model`` is not moved; only a copy is sent to the CPU.
    """
    out_dir = self.save_dir
    snapshot = self.model.copy()
    # Carry the name over to the copy explicitly.
    # NOTE(review): presumably Link.copy() does not preserve it -- confirm.
    snapshot.name = self.model.name
    snapshot.to_cpu()
    chainer.serializers.save_hdf5(os.path.join(out_dir, 'model.hdf'), snapshot)

    vocab_path = os.path.join(out_dir, "vocab.pkl")
    with open(vocab_path, "wb") as fh:
        pickle.dump((self.src_vcb, self.trg_vcb), fh)
def load_params(prefix, mdl, opt):
    """Restore model and optimizer state from NPZ files sharing ``prefix``.

    Reads ``prefix + '.mdl'`` into ``mdl`` and then ``prefix + '.opt'`` into
    ``opt`` with ``chainer.serializers.load_npz``, logging once beforehand.

    Args:
        prefix: Path prefix (string); ``.mdl`` / ``.opt`` are appended to it.
        mdl: Model object to populate (presumably a Chainer Link).
        opt: Optimizer object to populate (presumably a Chainer Optimizer).
    """
    logging.getLogger(__name__).info('Loading model/optimizer parameters')
    # Model first, then optimizer -- same order as the files are written.
    for suffix, target in (('.mdl', mdl), ('.opt', opt)):
        chainer.serializers.load_npz(prefix + suffix, target)
def save_params(prefix, mdl, opt):
    """Write model and optimizer state to NPZ files sharing ``prefix``.

    Serializes ``mdl`` to ``prefix + '.mdl'`` and then ``opt`` to
    ``prefix + '.opt'`` with ``chainer.serializers.save_npz``, logging once
    beforehand.

    Args:
        prefix: Path prefix (string); ``.mdl`` / ``.opt`` are appended to it.
        mdl: Model object to save (presumably a Chainer Link).
        opt: Optimizer object to save (presumably a Chainer Optimizer).
    """
    log = logging.getLogger(__name__)
    log.info('Saving model/optimizer parameters')
    save = chainer.serializers.save_npz
    save(prefix + '.mdl', mdl)
    save(prefix + '.opt', opt)