Python pytest 模块,param() 实例源码

我们从 Python 开源项目中提取了以下 18 个代码示例，用于说明如何使用 pytest.param()。

项目:pyro    作者:uber    | 项目源码 | 文件源码
def test_iter_discrete_traces_scalar(graph_type):
    """
    Enumerate a model with one Bernoulli site and one Categorical site and
    check the trace count and that each trace's scale equals the joint
    probability of the values it sampled.
    """
    pyro.clear_param_store()

    def model():
        bern_p = pyro.param("p", Variable(torch.Tensor([0.05])))
        cat_ps = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
        x_site = pyro.sample("x", dist.Bernoulli(bern_p))
        y_site = pyro.sample("y", dist.Categorical(cat_ps, one_hot=False))
        return {"x": x_site, "y": y_site}

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # Two Bernoulli outcomes times one trace per categorical outcome.
    assert len(traces) == 2 * len(ps)

    # Probability of x == 0 and x == 1, indexed by the sampled value.
    bern_probs = [1 - p[0], p[0]]
    for scale, trace in traces:
        x_val = trace.nodes["x"]["value"].data.long().view(-1)[0]
        y_val = trace.nodes["y"]["value"].data.long().view(-1)[0]
        joint_prob = bern_probs[x_val] * ps[y_val]
        assert_equal(scale, Variable(torch.Tensor([joint_prob])))
项目:pyro    作者:uber    | 项目源码 | 文件源码
def test_iter_discrete_traces_nan(enum_discrete, trace_graph):
    """
    Bernoulli probabilities of exactly 0 and 1 must not make the ELBO loss
    (from either ``loss`` or ``loss_and_grads``) a NaN or a non-float.
    """
    pyro.clear_param_store()
    probs = [0.0, 0.5, 1.0]

    def model():
        pyro.sample("z", dist.Bernoulli(Variable(torch.Tensor(probs))))

    def guide():
        guide_p = pyro.param("p", Variable(torch.Tensor(probs), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(guide_p))

    elbo_cls = TraceGraph_ELBO if trace_graph else Trace_ELBO
    elbo = elbo_cls(enum_discrete=enum_discrete)
    with xfail_if_not_implemented():
        for compute_loss in (elbo.loss, elbo.loss_and_grads):
            loss = compute_loss(model, guide)
            assert isinstance(loss, float) and not math.isnan(loss), loss


# A simple Gaussian mixture model, with no vectorization.
项目:pyro    作者:uber    | 项目源码 | 文件源码
def finite_difference(eval_loss, delta=0.1):
    """
    Approximate the gradient of ``eval_loss`` w.r.t. every param in the
    param store using central differences.

    Each scalar entry ``v`` of each param is temporarily set to ``v + delta``
    and then ``v - delta``. The loss is evaluated at both points, and the
    gradient entry is ``(loss(v + delta) - loss(v - delta)) / (2 * delta)``.
    The entry is restored to ``v`` afterwards, even if ``eval_loss`` raises.

    :param eval_loss: zero-argument callable returning a scalar loss that
        reads the current param-store values.
    :param delta: half-width of the central-difference step.
    :returns: dict mapping each param name to a Variable of gradient
        estimates with the same size as that param.
    :raises AssertionError: if the param store holds no params.
    """
    params = pyro.get_param_store().get_all_param_names()
    assert params, "no params found"
    grads = {name: Variable(torch.zeros(pyro.param(name).size())) for name in params}
    for name in sorted(params):
        value = pyro.param(name).data
        for index in itertools.product(*map(range, value.size())):
            # Snapshot as a plain Python number. On newer torch, value[index]
            # is a 0-dim view that would track the perturbations below.
            center = float(value[index])
            try:
                value[index] = center + delta
                pos = eval_loss()
                value[index] = center - delta
                neg = eval_loss()
            finally:
                # Always undo the perturbation so the param store is not left
                # modified if eval_loss raises.
                value[index] = center
            grads[name][index] = (pos - neg) / (2 * delta)
    return grads
项目:pyro    作者:uber    | 项目源码 | 文件源码
def test_gmm_elbo_gradient(model, guide, enum_discrete, trace_graph):
    """
    Check the surrogate-loss gradients of a GMM model/guide pair on the data
    ``[-1, 1]`` against a finite-difference estimate of the Trace_ELBO loss.
    """
    pyro.clear_param_store()
    num_particles = 4000
    data = Variable(torch.Tensor([-1, 1]))

    print("Computing gradients using surrogate loss")
    elbo_cls = TraceGraph_ELBO if trace_graph else Trace_ELBO
    # A single particle when discrete sites are enumerated, many otherwise.
    surrogate_particles = 1 if enum_discrete else num_particles
    surrogate_elbo = elbo_cls(enum_discrete=enum_discrete,
                              num_particles=surrogate_particles)
    with xfail_if_not_implemented():
        surrogate_elbo.loss_and_grads(model, guide, data)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {}
    for name in params:
        actual_grads[name] = pyro.param(name).grad.clone()

    print("Computing gradients using finite difference")
    mc_elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: mc_elbo.loss(model, guide, data))

    for name in params:
        actual, expected = actual_grads[name].data, expected_grads[name].data
        print("{} {}{}{}".format(name, "-" * 30, actual, expected))
    assert_equal(actual_grads, expected_grads, prec=0.5)
项目:treecat    作者:posterior    | 项目源码 | 文件源码
def xfail_param(*args, **kwargs):
    """Return ``pytest.param(*args)`` carrying an ``xfail`` mark built from ``kwargs``."""
    xfail_mark = pytest.mark.xfail(**kwargs)
    return pytest.param(*args, marks=xfail_mark)
项目:pyro    作者:uber    | 项目源码 | 文件源码
def test_iter_discrete_traces_vector(graph_type):
    """
    Enumerate a batched model (batch size 2) with one Bernoulli site and one
    Categorical site, and check the trace count and that each trace's scale
    equals the probability of its discrete assignment.
    """
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # Two Bernoulli outcomes times one trace per categorical outcome.
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        # NOTE(review): reads only the first batch element; the trace count
        # above suggests enumeration shares one value across the batch.
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # Joint probability of independent sites: log-densities add, so the
        # scale is exp(log p(x) + log p(y)), not exp(log p(x) * log p(y)).
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) +
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
项目:pyro    作者:uber    | 项目源码 | 文件源码
def gmm_model(data, verbose=False):
    """
    Non-vectorized two-component Gaussian mixture model.

    Each point ``data[i]`` gets a Bernoulli component indicator ``z_i`` with
    learnable probability ``p``. It is observed under a Normal centred at -1
    or 1 (chosen by ``z_i``) with learnable scale ``sigma``.
    """
    p = pyro.param("p", Variable(torch.Tensor([0.3]), requires_grad=True))
    sigma = pyro.param("sigma", Variable(torch.Tensor([1.0]), requires_grad=True))
    centers = Variable(torch.Tensor([-1, 1]))
    for i in pyro.irange("data", len(data)):
        z_sample = pyro.sample("z_{}".format(i), dist.Bernoulli(p))
        assert z_sample.size() == (1,)
        component = z_sample.long().data[0]
        if verbose:
            indent = "  " * i
            print("M{} z_{} = {}".format(indent, i, component))
        pyro.observe("x_{}".format(i), dist.Normal(centers[component], sigma), data[i])
项目:pyro    作者:uber    | 项目源码 | 文件源码
def gmm_guide(data, verbose=False):
    """
    Non-vectorized guide for the GMM. Each point gets its own learnable
    Bernoulli probability ``p_i`` (initialised to 0.6) for the indicator ``z_i``.
    """
    for i in pyro.irange("data", len(data)):
        p_i = pyro.param("p_{}".format(i), Variable(torch.Tensor([0.6]), requires_grad=True))
        z_sample = pyro.sample("z_{}".format(i), dist.Bernoulli(p_i))
        assert z_sample.size() == (1,)
        component = z_sample.long().data[0]
        if verbose:
            indent = "  " * i
            print("G{} z_{} = {}".format(indent, i, component))
项目:pyro    作者:uber    | 项目源码 | 文件源码
def gmm_batch_guide(data):
    """
    Vectorized guide for the GMM. A learnable ``(n, 1)`` weight column
    (initialised to 0.6) is expanded to ``[w, 1 - w]`` and one Categorical
    site ``z`` is sampled for the whole batch.
    """
    with pyro.iarange("data", len(data)) as batch:
        batch_size = len(batch)
        initial = torch.ones(batch_size, 1) * 0.6
        weight = pyro.param("ps", Variable(initial, requires_grad=True))
        probs = torch.cat([weight, 1 - weight], dim=1)
        z = pyro.sample("z", dist.Categorical(probs))
        assert z.size() == (batch_size, 2)
项目:pyro    作者:uber    | 项目源码 | 文件源码
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    """
    Check the surrogate-loss gradient of a single-Bernoulli model/guide pair
    against a finite-difference estimate of the Trace_ELBO loss.
    """
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        prior_p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(prior_p))

    def guide():
        q = pyro.param("p", Variable(torch.Tensor([0.5]), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(q))

    print("Computing gradients using surrogate loss")
    elbo_cls = TraceGraph_ELBO if trace_graph else Trace_ELBO
    # A single particle when discrete sites are enumerated, many otherwise.
    surrogate_particles = 1 if enum_discrete else num_particles
    surrogate_elbo = elbo_cls(enum_discrete=enum_discrete,
                              num_particles=surrogate_particles)
    with xfail_if_not_implemented():
        surrogate_elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {}
    for name in params:
        actual_grads[name] = pyro.param(name).grad.clone()

    print("Computing gradients using finite difference")
    mc_elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: mc_elbo.loss(model, guide))

    for name in params:
        actual, expected = actual_grads[name].data, expected_grads[name].data
        print("{} {}{}{}".format(name, "-" * 30, actual, expected))
    assert_equal(actual_grads, expected_grads, prec=0.1)
项目:py-evm    作者:ethereum    | 项目源码 | 文件源码
def filter_fixtures(all_fixtures, fixtures_base_dir, mark_fn=None, ignore_fn=None):
    """
    Yield the entries of ``all_fixtures`` that survive filtering.

    Each entry is a sequence whose first item is the fixture's path. The
    remaining items are forwarded to the callbacks after the fixture path,
    which is first made relative to ``fixtures_base_dir``.

    - `ignore_fn(relpath, *rest)`: a truthy result drops the entry.
    - `mark_fn(relpath, *rest)`: returns `None` or a `pytest.mark` object.
      A truthy mark wraps the entry as `pytest.param((path, *rest), marks=mark)`.
      Otherwise the entry is yielded unchanged.
    """
    for fixture_data in all_fixtures:
        fixture_path = fixture_data[0]
        rest = fixture_data[1:]
        relative_path = os.path.relpath(fixture_path, fixtures_base_dir)

        if ignore_fn and ignore_fn(relative_path, *rest):
            continue

        mark = None if mark_fn is None else mark_fn(relative_path, *rest)
        if not mark:
            yield fixture_data
        else:
            yield pytest.param((fixture_path, *rest), marks=mark)
项目:ml-utils    作者:LinxiFan    | 项目源码 | 文件源码
def param_fail(*args):
    """Return ``pytest.param(*args)`` marked as an expected failure (``xfail``)."""
    expected_failure = pytest.mark.xfail
    return pytest.param(*args, marks=expected_failure)
项目:ml-utils    作者:LinxiFan    | 项目源码 | 文件源码
def modarg(request):
    """
    Yield ``request.param`` and print a SETUP line before and a TEARDOWN
    line after.

    ``request`` is pytest's fixture request object; ``request.param`` holds
    the current parametrized value.
    """
    param = request.param
    # Wrap in a 1-tuple so a tuple-valued param is printed as one value rather
    # than being taken as the whole %-format argument list (TypeError).
    print("  SETUP modarg %s" % (param,))
    yield param
    # When used as a pytest yield fixture, code after yield is teardown.
    print("  TEARDOWN modarg %s" % (param,))
项目:ml-utils    作者:LinxiFan    | 项目源码 | 文件源码
def otherarg(request):
    """
    Yield ``request.param`` and print a SETUP line before and a TEARDOWN
    line after.
    """
    param = request.param
    # Wrap in a 1-tuple so a tuple-valued param is printed as one value rather
    # than being taken as the whole %-format argument list (TypeError).
    print("  SETUP otherarg %s" % (param,))
    yield param
    print("  TEARDOWN otherarg %s" % (param,))
项目:ml-utils    作者:LinxiFan    | 项目源码 | 文件源码
def param_fail(*args):
    """
    Wrap ``args`` in ``pytest.param`` with the ``xfail`` mark, so the case
    is expected to fail.

    Example::

        @pytest.mark.parametrize("test_input,expected", [
            ("3+5", 8),
            ("2+4", 6),
            pytest.param("6*9", 42,
                         marks=pytest.mark.xfail),
            param_fail('7-8', -10),
        ])
        def test_eval(test_input, expected):
            assert eval(test_input) == expected
    """
    return pytest.param(*args, marks=pytest.mark.xfail)

# TODO: write pytest '==' hooks for numpy
项目:multidict    作者:aio-libs    | 项目源码 | 文件源码
def _multidict(request):
    """Import and return the module named by ``request.param``, skipping the test if it cannot be imported."""
    module_name = request.param
    return pytest.importorskip(module_name)
项目:setuptools    作者:pypa    | 项目源码 | 文件源码
def parametrize_test_working_set_resolve(*test_list):
    """
    Build a ``pytest.mark.parametrize`` decorator for working-set resolve tests.

    Each entry of ``test_list`` is a text block that is dedented and split on
    blank lines into six paragraphs:
    - test name
    - installed distributions
    - installable distributions
    - requirements
    - expected result without ``replace_conflicting``
    - expected result with ``replace_conflicting``

    Every entry yields two cases, ``<name>`` and
    ``<name>_replace_conflicting``. An expected result that is a single word
    names an exception class in ``pkg_resources``. Any other expected result
    is parsed as a list of distributions.
    """
    idlist = []
    argvalues = []
    for test in test_list:
        (
            name,
            installed_dists,
            installable_dists,
            requirements,
            expected1, expected2
        ) = [
            strip_comments(s.lstrip()) for s in
            textwrap.dedent(test).lstrip().split('\n\n', 5)
        ]
        installed_dists = list(parse_distributions(installed_dists))
        installable_dists = list(parse_distributions(installable_dists))
        requirements = list(pkg_resources.parse_requirements(requirements))
        for id_, replace_conflicting, expected in (
            (name, False, expected1),
            (name + '_replace_conflicting', True, expected2),
        ):
            idlist.append(id_)
            expected = strip_comments(expected.strip())
            # Raw string: '\w' in a plain literal is an invalid escape sequence
            # (DeprecationWarning; SyntaxWarning since Python 3.12).
            if re.match(r'\w+$', expected):
                # A single word is the name of an exception class.
                expected = getattr(pkg_resources, expected)
                assert issubclass(expected, Exception)
            else:
                expected = list(parse_distributions(expected))
            argvalues.append(pytest.param(installed_dists, installable_dists,
                                          requirements, replace_conflicting,
                                          expected))
    return pytest.mark.parametrize('installed_dists,installable_dists,'
                                   'requirements,replace_conflicting,'
                                   'resolved_dists_or_exception',
                                   argvalues, ids=idlist)
项目:setuptools    作者:pypa    | 项目源码 | 文件源码
def parametrize(*test_list, **format_dict):
    """
    Build a ``pytest.mark.parametrize`` decorator for setup-requires tests.

    Each test in ``test_list`` is a text block of up to four blank-line
    separated paragraphs:
    - a header: the test name, optionally followed on its next line by a
      Python literal of install-command kwargs
    - the ``setup.py`` requirements
    - the ``setup.cfg`` requirements
    - the expected requirements

    The last three paragraphs are passed through ``DALS`` and then
    ``str.format``-ed with ``format_dict``.

    Every test produces two cases, ``<name>`` and ``<name>_in_setup_cfg``.
    A requirements paragraph starting with an ``@xfail`` line has that line
    removed, and its case is marked ``xfail``.
    """
    xfail_prefix = '@xfail\n'
    ids = []
    cases = []
    for test in test_list:
        header, *paragraphs = test.lstrip().split('\n\n', 3)
        header_lines = header.split('\n')
        if len(header_lines) > 1:
            install_cmd_kwargs = ast.literal_eval(header_lines[1].strip())
        else:
            install_cmd_kwargs = {}
        name = header_lines[0].strip()
        setup_py_requires, setup_cfg_requires, expected_requires = (
            DALS(paragraph).format(**format_dict) for paragraph in paragraphs
        )
        variants = (
            (name, setup_py_requires, False),
            (name + '_in_setup_cfg', setup_cfg_requires, True),
        )
        for case_id, requires, use_cfg in variants:
            ids.append(case_id)
            marks = ()
            if requires.startswith(xfail_prefix):
                requires = requires[len(xfail_prefix):]
                marks = pytest.mark.xfail
            cases.append(pytest.param(
                requires, use_cfg, expected_requires, install_cmd_kwargs,
                marks=marks,
            ))
    return pytest.mark.parametrize(
        'requires,use_setup_cfg,'
        'expected_requires,install_cmd_kwargs',
        cases, ids=ids,
    )