Python sklearn.base 模块,BaseEstimator() 实例源码

我们从 Python 开源项目中，提取了以下 21 个代码示例，用于说明如何使用 sklearn.base.BaseEstimator()。

项目:moabb    作者:NeuroTechX    | 项目源码 | 文件源码
def __init__(self, datasets, pipelines):
        """Validate and store the datasets and pipelines.

        A single ``BaseDataset`` is wrapped into a one-element list. Any
        other non-list value, or a list holding something that is not a
        ``BaseDataset``, is rejected. ``pipelines`` must be a dict whose
        values are all ``BaseEstimator`` instances.

        Raises
        ------
        ValueError
            If either argument fails the checks described above.
        """
        # Normalise a lone dataset into a list before checking the items.
        if not isinstance(datasets, list):
            if not isinstance(datasets, BaseDataset):
                raise ValueError("datasets must be a list or a dataset instance")
            datasets = [datasets]

        if not all(isinstance(entry, BaseDataset) for entry in datasets):
            raise ValueError("datasets must only contains dataset instance")

        self.datasets = datasets

        # Only the dict values are inspected; the keys are not checked.
        if not isinstance(pipelines, dict):
            raise ValueError("pipelines must be a dict or a Pipeline instance")

        if not all(isinstance(step, BaseEstimator)
                   for step in pipelines.values()):
            raise ValueError("pipelines must only contains Pipelines instance")

        self.pipelines = pipelines
项目:yellowbrick    作者:DistrictDataLabs    | 项目源码 | 文件源码
def is_estimator(model):
    """
    Report whether ``model`` is an estimator class or estimator instance.

    Parameters
    ----------
    model : class or instance
        The object to test. Classes are checked with ``issubclass`` and
        everything else with ``isinstance``, in both cases against
        ``BaseEstimator``.

    Returns
    -------
    bool
        True if ``model`` derives from (or is an instance of)
        ``BaseEstimator``.
    """
    # Pick the check that matches what we were handed, then apply it once.
    check = issubclass if inspect.isclass(model) else isinstance
    return check(model, BaseEstimator)

# Alias for closer name to isinstance and issubclass
项目:yellowbrick    作者:DistrictDataLabs    | 项目源码 | 文件源码
def test_subclass(self):
        """
        Check that a FeatureVisualizer instance has all expected base classes
        """
        visualizer = FeatureVisualizer()
        for expected_base in (TransformerMixin, BaseEstimator, Visualizer):
            self.assertIsInstance(visualizer, expected_base)

    # def test_interface(self):
    #     """
    #     Test the feature visualizer interface
    #     """
    #
    #     visualizer = FeatureVisualizer()
    #     with self.assertRaises(NotImplementedError):
    #         visualizer.poof()
项目:yellowbrick    作者:DistrictDataLabs    | 项目源码 | 文件源码
def test_subclass(self):
        """
        Check that a TextVisualizer instance has all expected base classes
        """
        visualizer = TextVisualizer()
        for expected_base in (TransformerMixin, BaseEstimator, Visualizer):
            self.assertIsInstance(visualizer, expected_base)

    # def test_interface(self):
    #     """
    #     Test the feature visualizer interface
    #     """
    #
    #     visualizer = TextVisualizer()
    #     with self.assertRaises(NotImplementedError):
    #         visualizer.poof()
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def test_sample_weight_adaboost_regressor():
    """
    Check that AdaBoostRegressor accepts a base estimator whose ``fit`` takes
    no ``sample_weight`` argument.

    The random weighted sampling is done internally in the _boost method in
    AdaBoostRegressor.
    """
    # Minimal estimator: fit is a no-op and predict returns one zero per row.
    class DummyEstimator(BaseEstimator):

        def fit(self, X, y):
            return None

        def predict(self, X):
            n_samples = X.shape[0]
            return np.zeros(n_samples)

    ensemble = AdaBoostRegressor(DummyEstimator(), n_estimators=3)
    ensemble.fit(X, y_regr)

    # One weight and one error are expected per fitted estimator.
    n_weights = len(ensemble.estimator_weights_)
    n_errors = len(ensemble.estimator_errors_)
    assert_equal(n_weights, n_errors)
项目:stacker    作者:bamine    | 项目源码 | 文件源码
def __init__(self, task: Task, models: List[BaseEstimator]):
        """Initialise the parent class with ``task`` and keep the models.

        Args:
            task: forwarded unchanged to the parent class initialiser.
            models: list of estimators, stored as ``self.models``.
        """
        super().__init__(task)
        self.models = models
项目:stacker    作者:bamine    | 项目源码 | 文件源码
def __init__(self, model: BaseEstimator, task: Task, space: Space, scorer: Scorer, opt_logger: OptimizationLogger):
        """Store the collaborators on the instance without further processing.

        Args:
            model: estimator stored as ``self.model``.
            task: stored as ``self.task``.
            space: stored as ``self.space``.
            scorer: stored as ``self.scorer``.
            opt_logger: stored as ``self.opt_logger``.
        """
        self.model = model
        self.task = task
        self.space = space
        self.scorer = scorer
        self.opt_logger = opt_logger
        # Starts out empty; presumably filled in once a best result is known
        # (the code that sets it is outside this snippet).
        self.best = None
项目:uncover-ml    作者:GeoscienceAustralia    | 项目源码 | 文件源码
def _check_sklearn_model(model):
    """Raise unless ``model`` is both a BaseEstimator and a RegressorMixin.

    Raises
    ------
    RuntimeError
        If ``model`` fails either ``isinstance`` check.
    """
    if isinstance(model, BaseEstimator) and isinstance(model, RegressorMixin):
        return
    raise RuntimeError('Needs to supply an instance of a scikit-learn '
                       'compatible regression class.')
项目:dask-searchcv    作者:dask    | 项目源码 | 文件源码
def normalize_estimator(est):
    """Build a ``(class name, normalized parameters)`` pair for an estimator.

    Note: scikit-learn only requires estimators to duck-type, not to subclass
    ``BaseEstimator``, so this function sometimes has to be called
    directly."""
    class_name = type(est).__name__
    params_token = normalize_token(est.get_params())
    return class_name, params_token
项目:healthcareai-py    作者:HealthCatalyst    | 项目源码 | 文件源码
def predict_regression(x_test, trained_estimator):
    """
    Validate the estimator, then return its ``predict`` output for ``x_test``

    Args:
        x_test: feature data handed straight to ``trained_estimator.predict``
        trained_estimator (sklearn.base.BaseEstimator): a trained scikit-learn estimator

    Returns:
        whatever ``trained_estimator.predict(x_test)`` returns
    """
    validate_estimator(trained_estimator)
    return trained_estimator.predict(x_test)
项目:healthcareai-py    作者:HealthCatalyst    | 项目源码 | 文件源码
def predict_classification(x_test, trained_estimator):
    """
    Validate the estimator, then return column 1 of its ``predict_proba``
    output for ``x_test``, squeezed with ``np.squeeze``

    Args:
        x_test: feature data handed straight to ``trained_estimator.predict_proba``
        trained_estimator (sklearn.base.BaseEstimator): a trained scikit-learn estimator

    Returns:
        the squeezed second column of the probability output
    """
    validate_estimator(trained_estimator)
    probabilities = trained_estimator.predict_proba(x_test)
    second_column = probabilities[:, 1]
    return np.squeeze(second_column)
项目:healthcareai-py    作者:HealthCatalyst    | 项目源码 | 文件源码
def validate_estimator(possible_estimator):
    """
    Check that the type of an object is a subclass of scikit-learn's BaseEstimator

    Args:
        possible_estimator (object): Object of any type.

    Returns:
        True when the check passes - the True is used only for testing

    Raises:
        HealthcareAIError: when the object's type does not derive from BaseEstimator
    """
    estimator_type = type(possible_estimator)
    if issubclass(estimator_type, BaseEstimator):
        return True

    message = 'Predictions require an estimator. You passed in {}, which is of type: {}'.format(
        possible_estimator, estimator_type)
    raise HealthcareAIError(message)
项目:document-qa    作者:allenai    | 项目源码 | 文件源码
def default(self, obj):
        """Convert objects the parent encoder's ``default`` is not asked first about.

        Checks are applied in a fixed order and the first match wins:
        numpy integer / dtype / floating / bool scalars, numpy arrays,
        ``BaseEstimator`` instances, ``Configuration`` objects,
        ``Configurable`` objects and sets. Anything else is passed to the
        parent ``default``; if that raises ``TypeError`` the object is
        rendered with ``str``.

        Raises:
            ValueError: if a ``Configuration``'s params already contain a
                "version" or "name" key, which would clash with the keys
                written here.
        """
        # numpy scalar/dtype types paired with the builtin that converts them.
        numpy_converters = (
            (np.integer, int),
            (np.dtype, str),
            (np.floating, float),
            (np.bool_, bool),
        )
        for numpy_type, convert in numpy_converters:
            if isinstance(obj, numpy_type):
                return convert(obj)

        if isinstance(obj, np.ndarray):
            return obj.tolist()

        if isinstance(obj, BaseEstimator):  # handle sklearn estimators
            return Configuration(obj.__class__.__name__, 0, obj.get_params())

        if isinstance(obj, Configuration):
            if any(reserved in obj.params for reserved in ("version", "name")):
                raise ValueError()
            encoded = OrderedDict()
            encoded["name"] = obj.name
            # A version of 0 is left out of the output.
            if obj.version != 0:
                encoded["version"] = obj.version
            encoded.update(obj.params)
            return encoded

        if isinstance(obj, Configurable):
            return obj.get_config()

        if isinstance(obj, set):
            return sorted(obj)  # Ensure deterministic order

        try:
            return super().default(obj)
        except TypeError:
            return str(obj)
项目:highdimensional-decision-boundary-plot    作者:tmadl    | 项目源码 | 文件源码
def setclassifier(self, estimator=None):
        """Assign classifier for which decision boundary should be plotted.

        Parameters
        ----------
        estimator : BaseEstimator instance, optional (default=KNeighborsClassifier(n_neighbors=10)).
            Classifier for which the decision boundary should be plotted. Must have
            probability estimates enabled (i.e. estimator.predict_proba must work).
            Make sure it is possible for probability estimates to get close to 0.5
            (more specifically, as close as specified by acceptance_threshold).
            When omitted (or ``None``), a new ``KNeighborsClassifier(n_neighbors=10)``
            is created for this call.
        """
        if estimator is None:
            # Build the default here rather than in the signature: a default
            # argument is evaluated once, so every call would otherwise share
            # the same estimator object.
            estimator = KNeighborsClassifier(n_neighbors=10)
        self.classifier = estimator
项目:ibex    作者:atavory    | 项目源码 | 文件源码
def _generate_bases_test(est, pd_est):
    """Build a test method comparing the base classes of ``est`` and ``pd_est``.

    The returned function takes the test-case instance as ``self`` and asserts
    that:

    * ``pd_est`` is a ``FrameMixin`` and a ``base.BaseEstimator``;
    * ``est`` is not a ``FrameMixin``;
    * for each scikit-learn mixin listed below, ``pd_est`` is an instance of
      it exactly when ``est`` is.
    """
    def test(self):
        self.assertTrue(isinstance(pd_est, FrameMixin), pd_est)
        self.assertFalse(isinstance(est, FrameMixin))
        self.assertTrue(isinstance(pd_est, base.BaseEstimator))

        mixins = [
            base.ClassifierMixin,
            base.ClusterMixin,
            base.BiclusterMixin,
            base.TransformerMixin,
            base.MetaEstimatorMixin,
            base.RegressorMixin]
        # DensityMixin is the only optional entry: it is skipped when the
        # installed ``base`` module lacks it, unless ``_sklearn_ver`` is
        # above 17, in which case its absence is an error.
        try:
            mixins.append(base.DensityMixin)
        except AttributeError:
            if _sklearn_ver > 17:
                raise

        for mixin in mixins:
            self.assertEqual(
                isinstance(pd_est, mixin),
                isinstance(est, mixin),
                mixin)

    return test
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def test_check_estimator():
    """Check that ``check_estimator`` rejects "bad" estimators and accepts real ones."""
    # tests that the estimator actually fails on "bad" estimators.
    # not a complete test of all checks, which are very extensive.

    # check that we have a get_params and can clone
    msg = "it does not implement a 'get_params' methods"
    assert_raises_regex(TypeError, msg, check_estimator, object)
    # check that we have a fit method
    msg = "object has no attribute 'fit'"
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)
    # check that fit does input validation
    msg = "TypeError not raised by fit"
    assert_raises_regex(AssertionError, msg, check_estimator, BaseBadClassifier)
    # check that predict does input validation (doesn't accept dicts in input)
    msg = "Estimator doesn't check for NaN and inf in predict"
    assert_raises_regex(AssertionError, msg, check_estimator, NoCheckinPredict)
    # check for sparse matrix input handling
    name = NoSparseClassifier.__name__
    msg = "Estimator " + name + " doesn't seem to fail gracefully on sparse data"
    # the check for sparse input handling prints to the stdout,
    # instead of raising an error, so as not to remove the original traceback.
    # that means we need to jump through some hoops to catch it.
    old_stdout = sys.stdout
    string_buffer = StringIO()
    sys.stdout = string_buffer
    try:
        check_estimator(NoSparseClassifier)
    except Exception:
        # Any failure here is deliberately ignored: only the captured output
        # matters. KeyboardInterrupt and SystemExit are not Exception
        # subclasses, so they still propagate.
        pass
    finally:
        sys.stdout = old_stdout
    assert_true(msg in string_buffer.getvalue())

    # doesn't error on actual estimator
    check_estimator(AdaBoostClassifier)
    check_estimator(MultiTaskElasticNet)
项目:hh-page-classifier    作者:TeamHG-Memex    | 项目源码 | 文件源码
def get_attributes(obj):
    """Extract the state of ``obj`` in a form that depends on its type.

    * ``TfidfVectorizer``: delegated to ``get_tfidf_attributes``.
    * ``XGBClassifier``: the whole object, pickled to bytes.
    * any other ``BaseEstimator``: a dict of its public attributes whose
      names end in ``_``, minus those listed in ``skip_attributes``.
    * ``None``: returns ``None``.

    Raises:
        TypeError: for any other non-None object.
    """
    if isinstance(obj, TfidfVectorizer):
        return get_tfidf_attributes(obj)

    if isinstance(obj, XGBClassifier):
        return pickle.dumps(obj)

    if isinstance(obj, BaseEstimator):
        collected = {}
        for attr in dir(obj):
            # Keep only public names with a trailing underscore that are not
            # explicitly skipped.
            if (attr.startswith('_') or not attr.endswith('_')
                    or attr in skip_attributes):
                continue
            collected[attr] = getattr(obj, attr)
        return collected

    if obj is not None:
        raise TypeError(type(obj))
    return None
项目:hh-page-classifier    作者:TeamHG-Memex    | 项目源码 | 文件源码
def set_attributes(parent, field, attributes):
    """Restore saved state onto ``getattr(parent, field)`` according to its type.

    * ``TfidfVectorizer``: delegated to ``set_ifidf_attributes``.
    * ``XGBClassifier``: ``attributes`` is unpickled and the result replaces
      ``parent.<field>`` entirely.
    * any other ``BaseEstimator``: each key/value of ``attributes`` is set as
      an attribute on the existing object.
    * ``None``: nothing is done.

    Raises:
        AttributeError: if an attribute cannot be set on the estimator.
        TypeError: for any other non-None object.
    """
    obj = getattr(parent, field)

    if isinstance(obj, TfidfVectorizer):
        set_ifidf_attributes(obj, attributes)
        return

    if isinstance(obj, XGBClassifier):
        # NOTE(review): pickle.loads can execute arbitrary code; this is only
        # safe if ``attributes`` comes from a trusted source - confirm.
        restored = pickle.loads(attributes)
        setattr(parent, field, restored)
        return

    if isinstance(obj, BaseEstimator):
        for name, value in attributes.items():
            try:
                setattr(obj, name, value)
            except AttributeError:
                raise AttributeError(
                    'can\'t set attribute {} on {}'.format(name, obj))
        return

    if obj is not None:
        raise TypeError(type(obj))
项目:extract    作者:dblalock    | 项目源码 | 文件源码
def wrap(func):
        """Return ``func`` wrapped in a ``FuncWrapper``.

        ``FuncWrapper`` is defined outside this snippet; its behaviour is not
        visible here.
        """
        return FuncWrapper(func)

    # BaseEstimator figures out what our params are based on the signature
    # of init, so we have to list them all here (though in this case it's
    # just the function we're wrapping)
项目:extract    作者:dblalock    | 项目源码 | 文件源码
def fit(self, X, y=None, **params):
        """Load the configured datasets into ``self.tsList_`` and return ``self``.

        ``X``, ``y`` and ``params`` are accepted but not used here; everything
        passed to ``loadDatasets`` comes from attributes on the instance.
        """
        # have to load dataset here, not in init, to
        # work with BaseEstimator cloning
        load_options = dict(
            seed=self.seed,
            whichExamples=self.whichExamples,
            instancesPerTs=self.instancesPerTs,
            minNumInstances=self.minNumInstances,
            maxNumInstances=self.maxNumInstances,
            cropDataLength=self.cropDataLength,
        )
        self.tsList_ = loadDatasets(self.datasetName, **load_options)

        return self
项目:ibex    作者:atavory    | 项目源码 | 文件源码
def run(self):
        """Generate the API ``.rst`` pages and build the documentation.

        For every public name in ``sklearn.__all__`` the matching
        ``sklearn.<name>`` module is imported and each attribute that is a
        ``BaseEstimator`` subclass gets its own page rendered from
        ``docs/source/api_class.rst.jinja2``. An index page is then rendered
        from ``docs/source/api.rst.jinja2``, and finally ``make text`` and
        ``make html`` are run in ``docs``.

        Modules that fail to import are reported and skipped. The exit status
        of ``make text`` is ignored, while a failing ``make html`` raises
        ``subprocess.CalledProcessError``. Unless ``self.reduced_checks`` is
        set, the html build also runs the ``spelling``, ``lint`` and
        ``linkcheck`` targets.
        """
        # NOTE(review): copy_tree is not used in the visible body; kept in
        # case it is needed beyond this snippet.
        from distutils.dir_util import copy_tree

        import sklearn
        from sklearn import base
        from jinja2 import Template

        with open('docs/source/api_class.rst.jinja2') as template_file:
            class_template = Template(template_file.read())

        sklearn_modules = {}
        for mod_name in sklearn.__all__:
            if mod_name.startswith('_'):
                continue

            try:
                orig = __import__('sklearn.%s' % mod_name, fromlist=[''])
            except Exception:
                # Report the module that failed. ``orig`` cannot be used here:
                # it is unbound on a first failure and otherwise still holds
                # the previously imported module.
                for _ in range(20):
                    print('failed to import %s' % mod_name)
                # Tmp Ami
                continue

            sklearn_modules[mod_name] = []

            for name in dir(orig):
                c = getattr(orig, name)
                try:
                    if not issubclass(c, base.BaseEstimator):
                        continue
                except TypeError:
                    # issubclass rejects non-class objects; skip those.
                    continue
                full_class_name = 'ibex.sklearn.%s.%s' % (mod_name, name)
                sklearn_modules[mod_name].append(full_class_name)
                content = class_template.render(
                    class_name=name,
                    full_class_name=full_class_name)
                f_name = 'docs/source/api_ibex_sklearn_%s_%s.rst' % (mod_name, name.lower())
                with open(f_name, 'w') as class_page:
                    class_page.write(content)

        with open('docs/source/api.rst.jinja2') as template_file:
            index_template = Template(template_file.read())
        content = index_template.render(
            sklearn_modules=sklearn_modules)
        f_name = 'docs/source/api.rst'
        with open(f_name, 'w') as index_page:
            index_page.write(content)

        run_str = 'make text'
        subprocess.call(run_str.split(' '), cwd='docs')

        run_str = 'make html'
        if not self.reduced_checks:
            run_str += ' spelling lint linkcheck'
        subprocess.check_call(run_str.split(' '), cwd='docs')