Python joblib module: dump() example source code

We extracted the following 50 code examples from open-source Python projects to illustrate how to use joblib.dump().
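
Before the project examples, here is a minimal self-contained sketch of the typical joblib.dump()/joblib.load() round trip (the file name example.pkl and the dictionary contents are illustrative, not drawn from any of the projects below):

import joblib
import numpy as np

# Any picklable Python object can be persisted; NumPy arrays inside the
# object are stored efficiently, which is joblib's main advantage over
# the standard pickle module.
data = {"weights": np.arange(10, dtype=np.float32), "name": "example"}

# compress=3 trades some CPU time for a smaller file, mirroring the
# compress=3 used by several of the snippets below.
joblib.dump(data, "example.pkl", compress=3)

# joblib.load returns a reconstructed copy of the saved object.
restored = joblib.load("example.pkl")
assert restored["name"] == "example"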

Project: py-smps    Author: dhhagan    | Project source code | File source code
def dump(self, filepath):
        """Save the SMPS object to disk"""
        return joblib.dump(self, filepath)
Project: third_person_im    Author: bstadie    | Project source code | File source code
def save_itr_params(itr, params):
    if _snapshot_dir:
        if _snapshot_mode == 'all':
            file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'last':
            # override previous params
            file_name = osp.join(_snapshot_dir, 'params.pkl')
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == "gap":
            if itr % _snapshot_gap == 0:
                file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
                joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'none':
            pass
        else:
            raise NotImplementedError
Project: third_person_im    Author: bstadie    | Project source code | File source code
def log_parameters(log_file, args, classes):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        if any([param_name.startswith(x) for x in list(classes.keys())]):
            continue
        log_params[param_name] = param_value
    for name, cls in classes.items():
        if isinstance(cls, type):
            params = get_all_parameters(cls, args)
            params["_name"] = getattr(args, name)
            log_params[name] = params
        else:
            log_params[name] = getattr(cls, "__kwargs", dict())
            log_params[name]["_name"] = cls.__module__ + "." + cls.__class__.__name__
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True)
Project: third_person_im    Author: bstadie    | Project source code | File source code
def log_parameters_lite(log_file, args):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        log_params[param_name] = param_value
    if args.args_data is not None:
        stub_method = pickle.loads(base64.b64decode(args.args_data))
        method_args = stub_method.kwargs
        log_params["json_args"] = dict()
        for k, v in list(method_args.items()):
            log_params["json_args"][k] = stub_to_json(v)
        kwargs = stub_method.obj.kwargs
        for k in ["baseline", "env", "policy"]:
            if k in kwargs:
                log_params["json_args"][k] = stub_to_json(kwargs.pop(k))
        log_params["json_args"]["algo"] = stub_to_json(stub_method.obj)
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True, cls=MyEncoder)
Project: scikitcrf_NER    Author: ManikandanThangavelu    | Project source code | File source code
def train(filePath):
    try:
        if not filePath.lower().endswith('json'):
            return {'success':False,'message':'Training file should be in json format'}
        with open(filePath) as file:
            ent_data = json.load(file)
        dataset = [jsonToCrf(q, nlp) for q in ent_data['entity_examples']]
        X_train = [sent2features(s) for s in dataset]
        y_train = [sent2labels(s) for s in dataset]
        crf = sklearn_crfsuite.CRF(
            algorithm='lbfgs',
            c1=0.1,
            c2=0.1,
            max_iterations=100,
            all_possible_transitions=True
        )
        crf.fit(X_train, y_train)
        if not os.path.exists("crfModel"):
            os.mkdir("crfModel")
        if os.path.isfile("crfModel/classifier.pkl"):
            os.remove("crfModel/classifier.pkl")
        joblib.dump(crf, "crfModel/classifier.pkl")
        return {'success':True,'message':'Model Trained Successfully'}
    except Exception as ex:
        return {'success':False,'message':'Error while Training the model - '+str(ex)}
Project: rllabplusplus    Author: shaneshixiang    | Project source code | File source code
def log_parameters(log_file, args, classes):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        if any([param_name.startswith(x) for x in list(classes.keys())]):
            continue
        log_params[param_name] = param_value
    for name, cls in classes.items():
        if isinstance(cls, type):
            params = get_all_parameters(cls, args)
            params["_name"] = getattr(args, name)
            log_params[name] = params
        else:
            log_params[name] = getattr(cls, "__kwargs", dict())
            log_params[name]["_name"] = cls.__module__ + "." + cls.__class__.__name__
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True)
Project: rllabplusplus    Author: shaneshixiang    | Project source code | File source code
def log_parameters_lite(log_file, args, json_kwargs=dict()):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        log_params[param_name] = param_value
    if args.args_data is not None:
        stub_method = pickle.loads(base64.b64decode(args.args_data))
        method_args = stub_method.kwargs
        log_params["json_args"] = dict()
        for k, v in list(json_kwargs.items()):
            log_params["json_args"][k] = v
        for k, v in list(method_args.items()):
            log_params["json_args"][k] = stub_to_json(v)
        kwargs = stub_method.obj.kwargs
        for k in ["baseline", "env", "policy"]:
            if k in kwargs:
                log_params["json_args"][k] = stub_to_json(kwargs.pop(k))
        log_params["json_args"]["algo"] = stub_to_json(stub_method.obj)
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True, cls=MyEncoder)
Project: rllabplusplus    Author: shaneshixiang    | Project source code | File source code
def save(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = logger.get_snapshot_dir()

        pool_file = os.path.join(checkpoint_dir, 'pool.chk')
        if self.save_format == 'pickle':
            pickle_dump(pool_file + '.tmp', self.pool)
        elif self.save_format == 'joblib':
            joblib.dump(self.pool, pool_file + '.tmp', compress=1, cache_size=1e9)
        else: raise NotImplementedError
        shutil.move(pool_file + '.tmp', pool_file)

        checkpoint_file = os.path.join(checkpoint_dir, 'params.chk')
        sess = tf.get_default_session()
        saver = tf.train.Saver()
        saver.save(sess, checkpoint_file)

        tabular_file = os.path.join(checkpoint_dir, 'progress.csv')
        if os.path.isfile(tabular_file):
            tabular_chk_file = os.path.join(checkpoint_dir, 'progress.csv.chk')
            shutil.copy(tabular_file, tabular_chk_file)

        logger.log('Saved to checkpoint %s'%checkpoint_file)
Project: CIKM2017    Author: heliarmk    | Project source code | File source code
def read_data_from_pkl(datafile):
    """
    read file in joblib.dump pkl
    :param datafile: filename of pkl
    :return: 
    """
    datas = joblib.load(datafile)
    for i in range(10):
        datas = np.random.permutation(datas)
    inputs, labels = [], []
    for data in datas:
        inputs.append(data["input"])
        labels.append(data["label"])

    inputs = np.array(inputs).reshape(-1, 15, 101, 101, 3).astype(np.float32)
    inputs -= np.mean(inputs, axis=(2, 3), keepdims=True)
    inputs /= np.std(inputs, axis=(2, 3), keepdims=True)
    labels = np.array(labels).reshape(-1, 1).astype(np.float32)

    return inputs, labels
Project: gail-driver    Author: sisl    | Project source code | File source code
def save_itr_params(itr, params):
    if _snapshot_dir:
        if _snapshot_mode == 'all':
            file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'last':
            # override previous params
            file_name = osp.join(_snapshot_dir, 'params.pkl')
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == "gap":
            if itr % _snapshot_gap == 0:
                file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
                joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'none':
            pass
        else:
            raise NotImplementedError
Project: gail-driver    Author: sisl    | Project source code | File source code
def log_parameters(log_file, args, classes):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        if any([param_name.startswith(x) for x in list(classes.keys())]):
            continue
        log_params[param_name] = param_value
    for name, cls in classes.items():
        if isinstance(cls, type):
            params = get_all_parameters(cls, args)
            params["_name"] = getattr(args, name)
            log_params[name] = params
        else:
            log_params[name] = getattr(cls, "__kwargs", dict())
            log_params[name]["_name"] = cls.__module__ + \
                "." + cls.__class__.__name__
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True)
Project: gail-driver    Author: sisl    | Project source code | File source code
def log_parameters_lite(log_file, args):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        log_params[param_name] = param_value
    if args.args_data is not None:
        stub_method = pickle.loads(base64.b64decode(args.args_data))
        method_args = stub_method.kwargs
        log_params["json_args"] = dict()
        for k, v in list(method_args.items()):
            log_params["json_args"][k] = stub_to_json(v)
        kwargs = stub_method.obj.kwargs
        for k in ["baseline", "env", "policy"]:
            if k in kwargs:
                log_params["json_args"][k] = stub_to_json(kwargs.pop(k))
        log_params["json_args"]["algo"] = stub_to_json(stub_method.obj)
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True, cls=MyEncoder)
Project: cesium_web    Author: cesium-ml    | Project source code | File source code
def add_file(model, create, value, *args, **kwargs):
        model_params = {
            "RandomForestClassifier": {
                "bootstrap": True, "criterion": "gini",
                "oob_score": False, "max_features": "auto",
                "n_estimators": 10, "random_state": 0},
            "RandomForestRegressor": {
                "bootstrap": True, "criterion": "mse",
                "oob_score": False, "max_features": "auto",
                "n_estimators": 10},
            "LinearSGDClassifier": {
                "loss": "hinge"},
            "LinearRegressor": {
                "fit_intercept": True}}
        fset_data, data = featurize.load_featureset(model.featureset.file_uri)
        model_data = MODELS_TYPE_DICT[model.type](**model_params[model.type])
        model_data.fit(fset_data, data['labels'])
        model.file_uri = pjoin('/tmp/', '{}.pkl'.format(str(uuid.uuid4())))
        joblib.dump(model_data, model.file_uri)
        DBSession().commit()
Project: rllab    Author: rll    | Project source code | File source code
def save_itr_params(itr, params):
    if _snapshot_dir:
        if _snapshot_mode == 'all':
            file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'last':
            # override previous params
            file_name = osp.join(_snapshot_dir, 'params.pkl')
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == "gap":
            if itr % _snapshot_gap == 0:
                file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
                joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'none':
            pass
        else:
            raise NotImplementedError
Project: rllab    Author: rll    | Project source code | File source code
def log_parameters(log_file, args, classes):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        if any([param_name.startswith(x) for x in list(classes.keys())]):
            continue
        log_params[param_name] = param_value
    for name, cls in classes.items():
        if isinstance(cls, type):
            params = get_all_parameters(cls, args)
            params["_name"] = getattr(args, name)
            log_params[name] = params
        else:
            log_params[name] = getattr(cls, "__kwargs", dict())
            log_params[name]["_name"] = cls.__module__ + "." + cls.__class__.__name__
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True)
Project: rllab    Author: rll    | Project source code | File source code
def log_parameters_lite(log_file, args):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        log_params[param_name] = param_value
    if args.args_data is not None:
        stub_method = pickle.loads(base64.b64decode(args.args_data))
        method_args = stub_method.kwargs
        log_params["json_args"] = dict()
        for k, v in list(method_args.items()):
            log_params["json_args"][k] = stub_to_json(v)
        kwargs = stub_method.obj.kwargs
        for k in ["baseline", "env", "policy"]:
            if k in kwargs:
                log_params["json_args"][k] = stub_to_json(kwargs.pop(k))
        log_params["json_args"]["algo"] = stub_to_json(stub_method.obj)
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True, cls=MyEncoder)
Project: maml_rl    Author: cbfinn    | Project source code | File source code
def save_itr_params(itr, params):
    if _snapshot_dir:
        if _snapshot_mode == 'all':
            file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'last':
            # override previous params
            file_name = osp.join(_snapshot_dir, 'params.pkl')
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == "gap":
            if itr % _snapshot_gap == 0:
                file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
                joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'none':
            pass
        else:
            raise NotImplementedError
Project: maml_rl    Author: cbfinn    | Project source code | File source code
def log_parameters(log_file, args, classes):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        if any([param_name.startswith(x) for x in list(classes.keys())]):
            continue
        log_params[param_name] = param_value
    for name, cls in classes.items():
        if isinstance(cls, type):
            params = get_all_parameters(cls, args)
            params["_name"] = getattr(args, name)
            log_params[name] = params
        else:
            log_params[name] = getattr(cls, "__kwargs", dict())
            log_params[name]["_name"] = cls.__module__ + "." + cls.__class__.__name__
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True)
Project: maml_rl    Author: cbfinn    | Project source code | File source code
def log_parameters_lite(log_file, args):
    log_params = {}
    for param_name, param_value in args.__dict__.items():
        log_params[param_name] = param_value
    if args.args_data is not None:
        stub_method = pickle.loads(base64.b64decode(args.args_data))
        method_args = stub_method.kwargs
        log_params["json_args"] = dict()
        for k, v in list(method_args.items()):
            log_params["json_args"][k] = stub_to_json(v)
        kwargs = stub_method.obj.kwargs
        for k in ["baseline", "env", "policy"]:
            if k in kwargs:
                log_params["json_args"][k] = stub_to_json(kwargs.pop(k))
        log_params["json_args"]["algo"] = stub_to_json(stub_method.obj)
    mkdir_p(os.path.dirname(log_file))
    with open(log_file, "w") as f:
        json.dump(log_params, f, indent=2, sort_keys=True, cls=MyEncoder)
Project: acl2017    Author: tttthomasssss    | Project source code | File source code
def save_vector_cache(vectors, vector_out_file, filetype='', **kwargs):
    logging.info("Saving {} vectors to cache {}".format(len(vectors),vector_out_file))
    if (vector_out_file.endswith('.dill') or filetype == 'dill'):
        with open(vector_out_file, 'wb') as data_file:
            dill.dump(vectors, data_file, protocol=kwargs.get('dill_protocol', 3))
    elif (vector_out_file.endswith('.joblib') or filetype == 'joblib'):
        joblib.dump(vectors, vector_out_file, compress=kwargs.get('joblib_compression', 3),
                    protocol=kwargs.get('joblib_protocol', 3))
    elif (vector_out_file.endswith('.sqlite') or filetype == 'sqlite'):
        autocommit = kwargs.pop('autocommit', True)
        if (isinstance(vectors, SqliteDict)):
            vectors.commit()
        else:
            with SqliteDict(vector_out_file, autocommit=autocommit) as data_file:
                for key, value in vectors.items():
                    data_file[key] = value

                if (not autocommit):
                    data_file.commit()
    else:
        raise NotImplementedError
Project: squeezeDet-hand    Author: fyhtea    | Project source code | File source code
def dump_caffemodel_weights():
  net = caffe.Net(args.prototxt_path, args.caffemodel_path, caffe.TEST)
  weights = {}
  n_layers = len(net.layers)
  for i in range(n_layers):
    layer_name = net._layer_names[i]
    layer = net.layers[i]
    layer_blobs = [o.data for o in layer.blobs]
    weights[layer_name] = layer_blobs
  joblib.dump(weights, args.caffe_weights_path)
Project: pybot    Author: spillai    | Project source code | File source code
def joblib_dump(item, path): 
    import joblib
    create_path_if_not_exists(path)
    joblib.dump(item, path)
Project: ISM2017    Author: ybayle    | Project source code | File source code
def create_model(clf_name, features, groundtruths, outdir, classifiers):
    begin = int(round(time.time() * 1000))
    utils.print_success("Starting " + clf_name)
    clf_dir = outdir + clf_name + "/"
    utils.create_dir(clf_dir)
    clf = classifiers[clf_name]
    clf.fit(features, groundtruths)
    joblib.dump(clf, clf_dir + clf_name + ".pkl")
    utils.print_info(clf_name + " done in " + str(int(round(time.time() * 1000)) - begin) + "ms")
Project: ProtScan    Author: gianlucacorrado    | Project source code | File source code
def save(self, model_name):
        """Save model to file."""
        joblib.dump(self, model_name, compress=1, protocol=2)
Project: keras-text    Author: raghakot    | Project source code | File source code
def dump(obj, file_name):
    if file_name.endswith('.json'):
        with open(file_name, 'w') as f:
            f.write(jsonpickle.dumps(obj))
        return

    if isinstance(obj, np.ndarray):
        np.save(file_name, obj)
        return

    # Using joblib instead of pickle because of http://bugs.python.org/issue11564
    joblib.dump(obj, file_name, protocol=pickle.HIGHEST_PROTOCOL)
Project: seglink    Author: bgshih    | Project source code | File source code
def dump_caffemodel_weights():
  net = caffe.Net(args.prototxt_path, args.caffemodel_path, caffe.TEST)
  weights = {}
  n_layers = len(net.layers)
  for i in range(n_layers):
    layer_name = net._layer_names[i]
    layer = net.layers[i]
    layer_blobs = [o.data for o in layer.blobs]
    weights[layer_name] = layer_blobs
  joblib.dump(weights, args.caffe_weights_path)
Project: third_person_im    Author: bstadie    | Project source code | File source code
def log_variant(log_file, variant_data):
    mkdir_p(os.path.dirname(log_file))
    if hasattr(variant_data, "dump"):
        variant_data = variant_data.dump()
    variant_json = stub_to_json(variant_data)
    with open(log_file, "w") as f:
        json.dump(variant_json, f, indent=2, sort_keys=True, cls=MyEncoder)
Project: deer    Author: VinF    | Project source code | File source code
def onEnd(self, agent):
        if not self._active:
            return

        bestIndex = np.argmax(self._validationScores)
        print("Best neural net obtained after {} epochs, with validation score {}".format(bestIndex+1, self._validationScores[bestIndex]))
        if self._testID is not None:
            print("Test score of this neural net: {}".format(self._testScores[bestIndex]))

        try:
            os.mkdir("scores")
        except Exception:
            pass
        basename = "scores/" + self._filename
        joblib.dump({"vs": self._validationScores, "ts": self._testScores}, basename + "_scores.jldump")
Project: deer    Author: VinF    | Project source code | File source code
def dumpNetwork(self, fname, nEpoch=-1):
        """ Dump the network

        Parameters
        -----------
        fname : string
            Name of the file where the network will be dumped
        nEpoch : int
            Epoch number (Optional)
        """
        try:
            os.mkdir("nnets")
        except Exception:
            pass
        basename = "nnets/" + fname

        for f in os.listdir("nnets/"):
            if fname in f:
                os.remove("nnets/" + f)

        all_params = self._network.getAllParams()

        if nEpoch >= 0:
            joblib.dump(all_params, basename + ".epoch={}".format(nEpoch))
        else:
            joblib.dump(all_params, basename, compress=True)
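Restoring such a dump is symmetric; a minimal sketch, assuming the network exposes a setAllParams() counterpart to getAllParams() (hypothetical here, not shown in the snippet above):

def loadNetwork(self, fname, nEpoch=-1):
    """Load parameters previously written by dumpNetwork (sketch)."""
    basename = "nnets/" + fname
    # mirror the naming scheme used by dumpNetwork above
    path = basename + ".epoch={}".format(nEpoch) if nEpoch >= 0 else basename
    all_params = joblib.load(path)
    self._network.setAllParams(all_params)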
Project: srep    Author: Answeror    | Project source code | File source code
def fit(self, train_data, eval_data, eval_metric='acc', **kargs):
        snapshot = kargs.pop('snapshot')
        self.clf.fit(*self._get_data_label(train_data))
        jb.dump(self.clf, snapshot + '-0001.params')

        if not isinstance(eval_metric, mx.metric.EvalMetric):
            eval_metric = mx.metric.create(eval_metric)
        data, label = self._get_data_label(eval_data)
        pred = self.clf.predict(data).astype(np.int64)
        prob = np.zeros((len(pred), pred.max() + 1))
        prob[np.arange(len(prob)), pred] = 1
        eval_metric.update([mx.nd.array(label)], [mx.nd.array(prob)])
        for name, val in eval_metric.get_name_value():
            logger.info('Epoch[0] Validation-{}={}', name, val)
Project: motif-classify    Author: macks22    | Project source code | File source code
def symbolize_signal(self, signal, parallel = None, n_jobs = -1):
        """
        Symbolize whole time-series signal to a sentence (vector of words),
        parallel can be {None, "ipython"}
        """
        window_index = self.sliding_window_index(len(signal))
        if parallel is None:
            return [self.symbolize_window(signal[wi]) for wi in window_index]
        elif parallel == "ipython":
            ## too slow
            raise NotImplementedError("parallel parameter %s not supported" % parallel)
            #return self.iparallel_symbolize_signal(signal)
        elif parallel == "joblib":
            with tempfile.NamedTemporaryFile(delete=False) as f:
                tf = f.name
            print "save temp file at %s" % tf 
            tfiles = joblib.dump(signal, tf)
            xs = joblib.load(tf, "r")
            n_jobs = joblib.cpu_count() if n_jobs == -1 else n_jobs 
            window_index = list(window_index)
            batch_size = len(window_index) / n_jobs
            batches = chunk(window_index, batch_size)
            symbols = Parallel(n_jobs)(delayed(joblib_symbolize_window)(self, xs, batch) for batch in batches)
            for f in tfiles: os.unlink(f)
            return sum(symbols, [])
        else:
            raise NotImplementedError("parallel parameter %s not supported" % parallel)
Project: motif-classify    Author: macks22    | Project source code | File source code
def signal_to_paa_vector(self, signal, n_jobs = -1):
        window_index = self.sliding_window_index(len(signal))
        with tempfile.NamedTemporaryFile(delete=False) as f:
            tf = f.name
        print("save temp file at %s" % tf)
        tfiles = joblib.dump(signal, tf)
        xs = joblib.load(tf, "r")
        n_jobs = joblib.cpu_count() if n_jobs == -1 else n_jobs
        window_index = list(window_index)
        batch_size = len(window_index) // n_jobs
        batches = chunk(window_index, batch_size)
        vecs = Parallel(n_jobs)(delayed(joblib_paa_window)(self, xs, batch) for batch in batches)
        for f in tfiles: os.unlink(f)
        return np.vstack(vecs)
Project: rllabplusplus    Author: shaneshixiang    | Project source code | File source code
def save_itr_params(itr, params):
    global _logger_info
    if _snapshot_dir:
        if _snapshot_mode == 'all':
            file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'last':
            # override previous params
            file_name = osp.join(_snapshot_dir, 'params.pkl')
            joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'last_best':
            # saves best and last params
            last_file_name = osp.join(_snapshot_dir, 'params.pkl')
            joblib.dump(params, last_file_name, compress=3)
            _logger_info["lastReward"] = get_last_tabular("AverageReturn")
            _logger_info["lastItr"] = get_last_tabular("Iteration")
            if "bestReward" not in _logger_info or _logger_info["bestReward"] < _logger_info["lastReward"]:
                best_file_name = osp.join(_snapshot_dir, 'params_best.pkl')
                shutil.copy(last_file_name, best_file_name)
                _logger_info["bestReward"] = _logger_info["lastReward"]
                _logger_info["bestItr"] = _logger_info["lastItr"]
        elif _snapshot_mode == 'last_all_best':
            # saves last and all best params
            last_file_name = osp.join(_snapshot_dir, 'params.pkl')
            joblib.dump(params, last_file_name, compress=3)
            _logger_info["lastReward"] = get_last_tabular("AverageReturn")
            _logger_info["lastItr"] = get_last_tabular("Iteration")
            if "bestReward" not in _logger_info or _logger_info["bestReward"] < _logger_info["lastReward"]:
                best_file_name = osp.join(_snapshot_dir, 'params_best_%08d.pkl' % itr)
                shutil.copy(last_file_name, best_file_name)
                _logger_info["bestReward"] = _logger_info["lastReward"]
                _logger_info["bestItr"] = _logger_info["lastItr"]
        elif _snapshot_mode == "gap":
            if itr % _snapshot_gap == 0:
                file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
                joblib.dump(params, file_name, compress=3)
        elif _snapshot_mode == 'none':
            pass
        else:
            raise NotImplementedError
Project: rllabplusplus    Author: shaneshixiang    | Project source code | File source code
def log_variant(log_file, variant_data):
    mkdir_p(os.path.dirname(log_file))
    if hasattr(variant_data, "dump"):
        variant_data = variant_data.dump()
    variant_json = stub_to_json(variant_data)
    with open(log_file, "w") as f:
        json.dump(variant_json, f, indent=2, sort_keys=True, cls=MyEncoder)
Project: AutoML-Challenge    Author: postech-mlg-exbrain    | Project source code | File source code
def _split_and_dump(self, X, y, valid_X, valid_y):
        if not hasattr(self, '_dm'):
            raise ValueError("It should be called after the dumpmanager _dm is set")

        if self.resampling == 'cv':
            pass
        elif self.resampling == 'holdout':
            if not self._has_valid_data:
                data_size = y.shape[0]
                if data_size >= 100000:
                    valid_ratio = 0.3
                elif 15000 <= data_size < 100000:
                    valid_ratio = 0.2
                else:
                    valid_ratio = 0.15
                valid_size = int(data_size * valid_ratio)
                X, valid_X = X[valid_size:], X[:valid_size]
                y, valid_y = y[valid_size:], y[:valid_size]
        else:
            raise NotImplementedError()

        pkl = {"resampling": self.resampling,
               "X": X, "y": y,
               "valid_X": valid_X, "valid_y": valid_y}

        datafile = os.path.join(self._dm.dir, "data.pkl")
        joblib.dump(pkl, datafile, protocol=-1)

        self._datafile = datafile
        return datafile
Project: SRLF    Author: Fritz449    | Project source code | File source code
def dump_object(data):
    # serialize an arbitrary object to a byte string
    s = BytesIO()
    joblib.dump(data, s)

    return s.getvalue()
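The matching deserializer (a sketch, not part of the SRLF source) wraps the byte string back in a BytesIO, since joblib.load accepts an open file object just as joblib.dump does:

def load_object(blob):
    # reverse of dump_object: rebuild the object from its byte string
    # (uses the same BytesIO and joblib imports as the snippet above)
    return joblib.load(BytesIO(blob))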
Project: kaggle-lung-cancer    Author: mdai    | Project source code | File source code
def process_study(study_id, annotations, out_dir, nstack):
    volumes_metadata = isotropic_volumes_metadata[study_id]
    isometric_volume = np.load('../data_proc/stage1/isotropic_volumes_1mm/{}.npy'.format(study_id))
    mean = np.mean(isometric_volume).astype(np.float32)
    std = np.std(isometric_volume).astype(np.float32)
    resize_factor = np.divide(volumes_metadata['volume_resampled_shape'], volumes_metadata['volume_shape'])

    coords_list = []
    for a in annotations:
        d = a['data']
        z = int(round(resize_factor[0] * a['sliceNum']))
        y0 = resize_factor[1] * d['y']
        y1 = resize_factor[1] * (d['y'] + d['height'])
        x0 = resize_factor[2] * d['x']
        x1 = resize_factor[2] * (d['x'] + d['width'])
        coords_list.append((z, y0, y1, x0, x1))

    samples = []
    for coords in coords_list:
        z, y0, y1, x0, x1 = coords
        for i in range(40):
            sample_id = uuid4()
            rand_y0 = max(0, int(round(y0 - random.randint(0, 32))))
            rand_y1 = min(isometric_volume.shape[1], int(round(y1 + random.randint(0, 32))))
            rand_x0 = max(0, int(round(x0 - random.randint(0, 32))))
            rand_x1 = min(isometric_volume.shape[2], int(round(x1 + random.randint(0, 32))))
            patch = []
            for zi in range(nstack):
                patch.append(resize(isometric_volume[z+zi, rand_y0:rand_y1, rand_x0:rand_x1], [32, 32],
                                    mode='edge', clip=True, preserve_range=True))
            patch = np.array(patch, dtype=np.float32)
            patch = (patch - mean) / (std + 1e-7)
            patch = np.moveaxis(patch, 0, 2)
            bb_x = (x0 - rand_x0) / (rand_x1 - rand_x0)
            bb_y = (y0 - rand_y0) / (rand_y1 - rand_y0)
            bb_w = (x1 - x0) / (rand_x1 - rand_x0)
            bb_h = (y1 - y0) / (rand_y1 - rand_y0)
            samples.append((patch, bb_x, bb_y, bb_w, bb_h))

    joblib.dump(samples, os.path.join(out_dir, 'samples', '{}.pkl'.format(study_id)))
    return len(samples)
Project: CIKM2017    Author: heliarmk    | Project source code | File source code
def agg(file_name, store_file):

    datas = joblib.load(file_name)
    new_datas = []

    for data in datas:
        new_datas.append(data)
        new_datas.append({"input":np.flip(data["input"],axis=2),"label":data["label"]})
        new_datas.append({"input":np.flip(data["input"],axis=3),"label":data["label"]})
        #new_datas.append({"input":np.rot90(m=data["input"],k=1,axes=(2,3)),"label":data["label"]})
        #new_datas.append({"input":np.rot90(m=data["input"],k=2,axes=(2,3)),"label":data["label"]})
        #new_datas.append({"input":np.rot90(m=data["input"],k=3,axes=(2,3)),"label":data["label"]})

    joblib.dump(value=new_datas,filename=store_file,compress=3)
Project: CIKM2017    Author: heliarmk    | Project source code | File source code
def slice_data(filename):
    data = joblib.load(filename=filename)
    for idx, i in enumerate(data):
        data[idx]["input"] = np.delete(data[idx]["input"],[3],axis=1)
        data[idx]["input"] = data[idx]["input"][:,:,46:55,46:55]
    name, suf = os.path.splitext(filename)
    outputfilename = name + "del_height_no.4_slice_7x7.pkl"
    joblib.dump(value=data, filename=outputfilename)
Project: CIKM2017    Author: heliarmk    | Project source code | File source code
def main():
    # Get the data.
    trains = joblib.load("../data/CIKM2017_train/train_Imp_3x3.pkl")
    #testa_set = joblib.load("../data/CIKM2017_testA/testA_Imp_3x3_del_height_no.4.pkl")
    #testa_x = []

    #for item in testa_set:
    #    testa_x.append(item["input"])

    #testa_x = np.asarray(testa_x, dtype=np.int16).transpose((0,1,3,4,2))
    train_x, train_y, train_class = sample(trains)
    '''
    for i in range(10):
        np.random.shuffle(data_set)
    valid_data_num = int(len(data_set) / 10) #get 10% data for validation
    for i in range(10):
        valid_set = data_set[i * valid_data_num : (i+1) * valid_data_num ]
        train_set = data_set[0: i*valid_data_num]
        train_set.extend(data_set[(i+1)*valid_data_num:])
        train_out, train_mean, train_std = preprocessing(train_set, 0, 0, True )
        valid_out = preprocessing(valid_set, train_mean, train_std)

        testa_out = preprocessing(testa_set, train_mean, train_std)

        convert_to(train_out, "train_Imp_3x3_resample_normalization_"+str(i)+"_fold", is_test=False)
        convert_to(valid_out, "valid_Imp_3x3_resample_normalization_"+str(i)+"_fold", is_test=False)
        convert_to(testa_out, "testA_Imp_3x3_normalization_"+str(i)+"_fold", is_test=True)
    #joblib.dump(value=data_set, filename="../data/CIKM2017_train/train_Imp_3x3_classified_del_height_no.4.pkl",compress=3)
    '''
    h5fname = "../data/CIKM2017_train/train_Imp_3x3.h5"
    import h5py
    "write file"
    with h5py.File(h5fname, "w") as f:
        #f.create_dataset(name="testa_set_x", shape=testa_x.shape, data=testa_x, dtype=testa_x.dtype, compression="lzf", chunks=True)
        f.create_dataset(name="train_set_x", shape=train_x.shape, data=train_x, dtype=train_x.dtype, compression="lzf", chunks=True)
        f.create_dataset(name="train_set_y", shape=train_y.shape, data=train_y, dtype=train_y.dtype, compression="lzf", chunks=True)
        f.create_dataset(name="train_set_class", shape=train_class.shape, data=train_class, dtype=train_class.dtype, compression="lzf", chunks=True)

    return
Project: Asynchronous-RL-agent    Author: Fritz449    | Project source code | File source code
def dump_object(data):
    # serialize an arbitrary object to a byte string
    s = BytesIO()
    joblib.dump(data, s)
    return s.getvalue()
Project: senti    Author: stevenxxiu    | Project source code | File source code
def test_pickle(self):
        joblib.dump(CachedIterable(self.iterator(), 3), 'output')
        self.assertListEqual(list(joblib.load('output')), list(range(20)))
Project: senti    Author: stevenxxiu    | Project source code | File source code
def main():
    os.chdir('data/google')
    model = Word2Vec.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True, norm_only=False)
    for v in model.vocab.values():
        v.sample_int = 0
    ts = list(model.vocab.items())
    ts.sort(key=lambda t: t[1].index)
    model.vocab = OrderedDict(ts)
    joblib.dump(model, 'GoogleNews-vectors-negative300.pickle')
Project: vgg16.tf    Author: bgshih    | Project source code | File source code
def dump_caffemodel_weights():
  net = caffe.Net(args.prototxt_path, args.caffemodel_path, caffe.TEST)
  weights = {}
  n_layers = len(net.layers)
  for i in range(n_layers):
    layer_name = net._layer_names[i]
    layer = net.layers[i]
    layer_blobs = [o.data for o in layer.blobs]
    weights[layer_name] = layer_blobs
  joblib.dump(weights, args.caffe_weights_path)
Project: xcessiv    Author: reiinakano    | Project source code | File source code
def save(self, filepath):
        joblib.dump(self, filepath, 3)
Project: soinn    Author: fukatani    | Project source code | File source code
def save(self, dumpfile='soinn.dump'):
        import joblib
        joblib.dump(self, dumpfile, compress=True, protocol=0)
Project: gail-driver    Author: sisl    | Project source code | File source code
def log_variant(log_file, variant_data):
    mkdir_p(os.path.dirname(log_file))
    if hasattr(variant_data, "dump"):
        variant_data = variant_data.dump()
    variant_json = stub_to_json(variant_data)
    with open(log_file, "w") as f:
        json.dump(variant_json, f, indent=2, sort_keys=True, cls=MyEncoder)
Project: TensorArtist    Author: vacancy    | Project source code | File source code
def dump(path, content, method=None, py_prefix='', py_suffix='', text_mode='w'):
    if method is None:
        method = _infer_method(path)

    assert_instance(method, IOMethod)

    path_origin = path
    if method != IOMethod.TEXT or text_mode == 'w':
        path += '.tmp'

    if method == IOMethod.PICKLE:
        with open(path, 'wb') as f:
            pickle.dump(content, f, protocol=pickle.HIGHEST_PROTOCOL)
    elif method == IOMethod.PICKLE_GZ:
        with gzip.open(path, 'wb') as f:
            pickle.dump(content, f, protocol=pickle.HIGHEST_PROTOCOL)
    elif method == IOMethod.NUMPY:
        joblib.dump(content, path)
    elif method == IOMethod.NUMPY_RAW:
        with open(path, 'wb') as f:
            content.dump(f)
    elif method == IOMethod.TEXT:
        with open(path, text_mode) as f:
            if type(content) in (list, tuple):
                f.writelines(content)
            else:
                f.write(str(content))
    elif method == IOMethod.BINARY:
        with open(path, 'wb') as f:
            f.write(content)
    else:
        raise ValueError('Unsupported dumping method: {}'.format(method))

    if method != IOMethod.TEXT or text_mode == 'w':
        os.rename(path, path_origin)

    return path_origin
Project: quantgov    Author: QuantGov    | Project source code | File source code
def save(self, path):
        """
        Use joblib to pickle the object.

        Arguments:
            path: an open file object or string holding the path to where the
                object should be saved
        """
        jl.dump(self, path)
Project: cLoops    Author: YaqiangCao    | Project source code | File source code
def txt2jd(f):
    """
    Dump the np.ndarray using joblib.dump for fast access.
    """
    data = []
    for line in open(f):
        line = line.split("\n")[0].split("\t")
        data.append(list(map(int, line)))
    data = np.array(data)
    joblib.dump(data, f.replace(".txt", ".jd"))
    os.remove(f)
    return f.replace(".txt", ".jd")