Python chainer module: Optimizer() code examples

The following 8 code examples, extracted from open-source Python projects, illustrate how chainer.Optimizer() is used.
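
The snippets below are fragments and omit their imports; they assume a common preamble along these lines. Note that chainer.Optimizer is an abstract base class: concrete optimizers such as chainer.optimizers.Adam subclass it, and setup() binds one to a model. A minimal sketch (not taken from any of the projects below):

import multiprocessing as mp
import os

import chainer
import numpy as np

model = chainer.links.Linear(3, 2)
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)          # after this, optimizer.target is the model
assert isinstance(optimizer, chainer.Optimizer)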

Project: deep_metric_learning | Author: ronekko
def save(self, dir_name):
    dir_path = os.path.join(self._root_dir_path, dir_name)
    if not os.path.exists(dir_path):
        os.mkdir(dir_path)

    others = []
    for key, value in self.items():
        # Skip private entries.
        if key.startswith('_'):
            continue

        # Persist each value in a format suited to its type.
        if isinstance(value, (np.ndarray, list)):
            np.save(os.path.join(dir_path, key + ".npy"), value)
        elif isinstance(value, (chainer.Chain, chainer.ChainList)):
            model_path = os.path.join(dir_path, "model.npz")
            chainer.serializers.save_npz(model_path, value)
        elif isinstance(value, chainer.Optimizer):
            optimizer_path = os.path.join(dir_path, "optimizer.npz")
            chainer.serializers.save_npz(optimizer_path, value)
        else:
            # Everything else goes into a plain-text log.
            others.append("{}: {}".format(key, value))

    with open(os.path.join(dir_path, "log.txt"), "a") as f:
        text = "\n".join(others) + "\n"
        f.write(text)
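
For completeness, the matching restore would use chainer.serializers.load_npz. A minimal sketch, assuming the model and optimizer have already been rebuilt with the same architecture (the load helper itself is hypothetical, not part of the project):

def load(dir_path, model, optimizer):
    # Restore what save() wrote with save_npz.
    chainer.serializers.load_npz(os.path.join(dir_path, "model.npz"), model)
    chainer.serializers.load_npz(
        os.path.join(dir_path, "optimizer.npz"), optimizer)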
Project: chainerrl | Author: chainer
def set_shared_states(a, b):
    assert isinstance(a, chainer.Optimizer)
    assert hasattr(a, 'target'), 'Optimizer.setup must be called first'
    for param_name, param in a.target.namedparams():
        ensure_initialized_update_rule(param)
        state = param.update_rule.state
        for state_name, state_val in b[param_name].items():
            # Rebind each state entry to a NumPy view onto the shared
            # buffer, preserving the original dtype and shape.
            s = state[state_name]
            state[state_name] = np.frombuffer(
                state_val, dtype=s.dtype).reshape(s.shape)
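
ensure_initialized_update_rule is a chainerrl helper not shown among these snippets; it has to force Chainer's lazily allocated UpdateRule state into existence before the state can be remapped. A sketch of what it plausibly does, assuming the Chainer v2 UpdateRule API (_state and init_state):

def ensure_initialized_update_rule(param):
    u = param.update_rule
    if u.state is None:
        # Allocate the state arrays (e.g. Adam's 'm' and 'v') up front.
        u._state = {}
        u.init_state(param)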
Project: chainerrl | Author: chainer
def extract_states_as_shared_arrays(optimizer):
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'
    shared_arrays = {}
    for param_name, param in optimizer.target.namedparams():
        shared_arrays[param_name] = {}
        ensure_initialized_update_rule(param)
        state = param.update_rule.state
        for state_name, state_val in state.items():
            # Flatten each state array into a shared float32 buffer.
            shared_arrays[param_name][state_name] = mp.RawArray(
                'f', state_val.ravel())
    return shared_arrays
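
The 'f' typecode allocates a flat float32 mp.RawArray, which child processes inherit on fork; np.frombuffer then gives a zero-copy NumPy view, which is what makes the remapping in set_shared_states work. A standalone illustration:

buf = mp.RawArray('f', [0.0, 0.0, 0.0, 0.0])
view = np.frombuffer(buf, dtype=np.float32)
view[0] = 1.0            # writes through to the shared buffer
print(buf[0])            # 1.0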
Project: chainerrl | Author: chainer
def as_shared_objects(obj):
    if isinstance(obj, tuple):
        return tuple(as_shared_objects(x) for x in obj)
    elif isinstance(obj, chainer.Link):
        return share_params_as_shared_arrays(obj)
    elif isinstance(obj, chainer.Optimizer):
        return share_states_as_shared_arrays(obj)
    elif isinstance(obj, mp.sharedctypes.Synchronized):
        return obj
    else:
        raise ValueError('Cannot share object of type {}'.format(type(obj)))
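
share_params_as_shared_arrays and share_states_as_shared_arrays are not among the extracted snippets; judging by the extract/set pairs above, they presumably allocate the shared buffers and immediately remap the object onto them, along these lines (a sketch, not verified against the repository):

def share_states_as_shared_arrays(optimizer):
    shared_arrays = extract_states_as_shared_arrays(optimizer)
    set_shared_states(optimizer, shared_arrays)
    return shared_arrays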
Project: chainerrl | Author: chainer
def synchronize_to_shared_objects(obj, shared_memory):
    if isinstance(obj, tuple):
        return tuple(synchronize_to_shared_objects(o, s)
                     for o, s in zip(obj, shared_memory))
    elif isinstance(obj, chainer.Link):
        set_shared_params(obj, shared_memory)
        return obj
    elif isinstance(obj, chainer.Optimizer):
        set_shared_states(obj, shared_memory)
        return obj
    elif isinstance(obj, mp.sharedctypes.Synchronized):
        return shared_memory
    else:
        raise ValueError('Cannot synchronize object of type {}'.format(type(obj)))
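
The two dispatchers mirror each other across a fork: the parent calls as_shared_objects once to allocate shared memory, and every worker calls synchronize_to_shared_objects to remap its own freshly built objects onto it. A hedged sketch (make_model_and_optimizer is a hypothetical factory):

shared = as_shared_objects((model, optimizer))

def worker():
    local_model, local_optimizer = make_model_and_optimizer()
    synchronize_to_shared_objects((local_model, local_optimizer), shared)
    # ... asynchronous training loop; updates land in shared memory ...

processes = [mp.Process(target=worker) for _ in range(4)]
for p in processes:
    p.start()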
Project: Comicolorization | Author: DwangoMediaVillage
def __init__(
        self,
        args,
        loss_maker,
        main_optimizer,
        main_lossfun,
        reinput_optimizer=None,
        reinput_lossfun=None,
        discriminator_optimizer=None,
        discriminator_lossfun=None,
        *_args, **kwargs
):
    # type: (any, comicolorization.loss.LossMaker, any, typing.Callable[[typing.Dict], any], typing.List[chainer.Optimizer], typing.Callable[[int, typing.Dict], any], any, typing.Callable[[typing.Dict], any], *any, **any) -> None
    # Register every optimizer under a name for the parent updater.
    optimizers = {'main': main_optimizer}
    if reinput_optimizer is not None:
        for i_reinput, optimizer in enumerate(reinput_optimizer):
            optimizers['reinput{}'.format(i_reinput)] = optimizer

    if discriminator_optimizer is not None:
        optimizers['discriminator'] = discriminator_optimizer

    super().__init__(optimizer=optimizers, *_args, **kwargs)

    # chainer.reporter cannot handle multiple optimizers that focus on
    # the same model, so reuse the main optimizer for each reinput.
    if args.separate_backward_reinput and reinput_optimizer is None:
        reinput_optimizer = [main_optimizer for _ in range(len(args.loss_blend_ratio_reinput))]

    self.args = args
    self.loss_maker = loss_maker
    self.main_optimizer = main_optimizer
    self.main_lossfun = main_lossfun
    self.reinput_optimizer = reinput_optimizer
    self.reinput_lossfun = reinput_lossfun
    self.discriminator_optimizer = discriminator_optimizer
    self.discriminator_lossfun = discriminator_lossfun
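
The optimizers dict follows Chainer's updater convention: chainer.training.StandardUpdater accepts either a single optimizer or a dict mapping names to optimizers, with 'main' as the default name. A minimal sketch with hypothetical generator/discriminator chains and iterator:

gen_opt = chainer.optimizers.Adam()
gen_opt.setup(generator)            # hypothetical chainer.Chain
dis_opt = chainer.optimizers.Adam()
dis_opt.setup(discriminator)        # hypothetical chainer.Chain

updater = chainer.training.StandardUpdater(
    train_iter,                     # hypothetical iterator
    optimizer={'main': gen_opt, 'discriminator': dis_opt})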
Project: async-rl | Author: muupan
def set_shared_states(a, b):
    assert isinstance(a, chainer.Optimizer)
    assert hasattr(a, 'target'), 'Optimizer.setup must be called first'
    # Chainer v1 kept optimizer state in the private _states dict.
    for state_name, shared_state in b.items():
        for param_name, param in shared_state.items():
            # Rebind each entry to a NumPy view onto the shared buffer,
            # preserving the original dtype and shape.
            old_param = a._states[state_name][param_name]
            a._states[state_name][param_name] = np.frombuffer(
                param, dtype=old_param.dtype).reshape(old_param.shape)
Project: async-rl | Author: muupan
def extract_states_as_shared_arrays(optimizer):
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'
    shared_arrays = {}
    # Chainer v1 API: iterate over the optimizer's internal state dict.
    for state_name, state in optimizer._states.items():
        shared_arrays[state_name] = {}
        for param_name, param in state.items():
            # Flatten each state array into a shared float32 buffer.
            shared_arrays[state_name][param_name] = mp.RawArray(
                'f', param.ravel())
    return shared_arrays
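
Note that both extract helpers hard-code the 'f' (float32) typecode, so they implicitly assume float32 optimizer state. A Chainer-free demonstration of the resulting down-cast:

state = np.array([0.1, 0.2], dtype=np.float64)
buf = mp.RawArray('f', state)                 # values cast to float32
print(np.frombuffer(buf, dtype=np.float32))   # [0.1 0.2] at float32 precision
# set_shared_states would then call np.frombuffer(buf, dtype=state.dtype),
# reinterpreting the 8-byte buffer as a single float64, and the subsequent
# reshape to the original shape would fail.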