Python unittest.mock 模块，patch() 实例源码

我们从 Python 开源项目中提取了以下 45 个代码示例，用于说明如何使用 unittest.mock.patch()。

项目:jwt_apns_client    作者:Mobelux    | 项目源码 | 文件源码
def test_make_provider_token_calls_jwt_encode_with_correct_args(self):
        """
        Check the arguments APNSConnection.make_provider_token() hands to jwt.encode().

        The token produced by jwt.encode() is not reproducible for identical input,
        so the test inspects the call made to a patched jwt.encode instead of
        comparing the returned token.
        """
        now = time.time()
        conn = jwt_apns_client.APNSConnection(
            team_id='TEAMID',
            apns_key_id='KEYID',
            apns_key_path=self.KEY_FILE_PATH,
        )

        with mock.patch('jwt_apns_client.utils.jwt.encode') as fake_encode:
            conn.make_provider_token(issued_at=now)
            expected_claims = {'iss': conn.team_id, 'iat': now}
            fake_encode.assert_called_with(
                expected_claims,
                conn.secret,
                algorithm=conn.algorithm,
                headers=conn.get_token_headers(),
            )
项目:django-heartbeat    作者:pbs    | 项目源码 | 文件源码
def test_redis_status(self, mock_redis):
        """The redis checker reports the ping reply and server version from StrictRedis.

        NOTE(review): CACHEOPS_REDIS is assigned on the global settings object and is
        not restored here; confirm a tearDown resets it.
        """
        settings.CACHEOPS_REDIS = {'host': 'foo', 'port': 1337}
        client = mock_redis.StrictRedis.return_value
        client.ping.return_value = 'PONG'
        client.info.return_value = {'redis_version': '1.0.0'}

        result = redis_status.check(request=None)

        assert result['ping'] == 'PONG'
        assert result['version'] == '1.0.0'

    # @mock.patch('heartbeat.checkers.redis_status.redis')
    # def test_redis_connection_error(self, mock_redis):
    #     setattr(settings, 'CACHEOPS_REDIS', {'host': 'foo', 'port': 1337})
    #     mock_ping = mock_redis.StrictRedis.return_value.ping
    #     mock_ping.side_effect = ConnectionError('foo')
    #     status = redis.check(request=None)
    #     assert status['error'] == 'foo', status
项目:django-heartbeat    作者:pbs    | 项目源码 | 文件源码
def test_db_version(self):
        """databases.check() reports the version string read through a DB cursor."""
        import django

        # Django >= 1.7 exposes CursorWrapper from `backends.utils`; older releases
        # from `backends.util`.
        helpers_module = 'utils' if django.VERSION >= (1, 7) else 'util'
        cursor_path = 'django.db.backends.{}.CursorWrapper'.format(helpers_module)

        with mock.patch(cursor_path) as fake_cursor:
            fake_cursor.return_value.fetchone.return_value = ['1.0.0']
            settings.DATABASES = {
                'default': {
                    'ENGINE': 'django.db.backends.sqlite3',
                    'NAME': 'foo',
                },
            }
            results = databases.check(request=None)

        assert len(results) == 1
        assert results[0]['version'] == '1.0.0'
项目:tts-bug-bounty-dashboard    作者:18F    | 项目源码 | 文件源码
def call_runscheduler(loops=1, mock_call_command=None):
    """Run the ``runscheduler`` management command for a bounded number of loops.

    ``time.sleep`` is replaced by a fake that raises KeyboardInterrupt on its
    ``loops + 1``-th call, which is how the command's loop is stopped. The
    command's ``call_command`` and ``logger`` are patched on the
    ``runscheduler`` module.

    Returns a ``(call_command_mock, logger_mock)`` tuple for inspection.
    """
    sleeps = [0]

    def fake_sleep(seconds):
        sleeps[0] += 1
        if sleeps[0] > loops:
            raise KeyboardInterrupt()

    command_mock = mock.MagicMock() if mock_call_command is None else mock_call_command

    with mock.patch.object(runscheduler, 'call_command', command_mock), \
            mock.patch.object(runscheduler, 'logger') as logger_mock:
        with mock.patch('time.sleep', fake_sleep), pytest.raises(KeyboardInterrupt):
            call_command('runscheduler')
        return command_mock, logger_mock
项目:cxflow-tensorflow    作者:Cognexa    | 项目源码 | 文件源码
def test_args(self):
        """Check WriteTensorBoard's argument validation and how it builds its ``FileWriter``."""
        model = self.get_model()

        # Each of these argument sets must be rejected with an AssertionError.
        rejected = [
            dict(model=42),
            dict(model=model, on_missing_variable='not-recognized'),
            dict(model=model, on_unknown_type='not-recognized'),
        ]
        for bad_kwargs in rejected:
            with self.assertRaises(AssertionError):
                WriteTensorBoard(output_dir=self.tmpdir, **bad_kwargs)

        with mock.patch('tensorflow.summary.FileWriter', autospec=True) as writer_cls:
            # With visualize_graph=True the model graph is passed to the writer.
            _ = WriteTensorBoard(output_dir=self.tmpdir, model=model, flush_secs=42, visualize_graph=True)
            writer_cls.assert_called_with(logdir=self.tmpdir, flush_secs=42, graph=model.graph)

            # Defaults: flush every 10 seconds, no graph.
            _ = WriteTensorBoard(output_dir=self.tmpdir, model=model)
            writer_cls.assert_called_with(logdir=self.tmpdir, flush_secs=10, graph=None)
项目:cxflow-tensorflow    作者:Cognexa    | 项目源码 | 文件源码
def test_missing_variable(self):
        """Exercise every ``on_missing_variable`` policy when the image variable is absent."""
        epoch_data = {'valid': {}}

        def make_hook(policy):
            return WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model(),
                                    image_variables=['plot'], on_missing_variable=policy)

        with mock.patch.dict('sys.modules', cv2=cv2_mock):
            # 'ignore': nothing at all is logged.
            quiet_hook = make_hook('ignore')
            with LogCapture(level=logging.INFO) as captured:
                quiet_hook.after_epoch(42, epoch_data)
            captured.check()

            # 'warn': exactly one warning on the root logger.
            warning_hook = make_hook('warn')
            with LogCapture(level=logging.INFO) as captured:
                warning_hook.after_epoch(42, epoch_data)
            captured.check(('root', 'WARNING', '`plot` not found in epoch data.'))

            # 'error': the missing variable surfaces as a KeyError.
            strict_hook = make_hook('error')
            with self.assertRaises(KeyError):
                strict_hook.after_epoch(42, epoch_data)
项目:latexipy    作者:masasin    | 项目源码 | 文件源码
def test_params_dict_after_font_size(self):
        """Entries in ``params_dict`` take precedence over sizes derived from ``font_size``."""
        with patch('matplotlib.rcParams.update') as rc_update, \
                patch('matplotlib.pyplot.switch_backend'):
            saved_params = dict(plt.rcParams)
            overrides = {'axes.labelsize': 12, 'legend.fontsize': 12}
            with lp.temp_params(font_size=10, params_dict=overrides):
                applied = rc_update.call_args[0][0]
                expected = {
                    'font.size': 10,
                    'axes.labelsize': 12,
                    'axes.titlesize': 10,
                    'legend.fontsize': 12,
                    'xtick.labelsize': 10,
                    'ytick.labelsize': 10,
                }
                for key, value in expected.items():
                    assert applied[key] == value

            # Leaving the context must push the previously saved params back.
            rc_update.assert_called_with(saved_params)
项目:latexipy    作者:masasin    | 项目源码 | 文件源码
def test_parameters_passed_custom_kwargs(self):
        """Custom keyword arguments given to lp.figure() are forwarded to save_figure().

        ``Figure.set_size_inches`` is patched out so no real resizing happens, and
        ``latexipy._latexipy.save_figure`` is patched so the forwarded arguments can
        be asserted once the context manager exits.
        """
        with patch('matplotlib.figure.Figure.set_size_inches'), \
                patch('latexipy._latexipy.save_figure') as mock_save_figure:
            with lp.figure('filename', directory='directory', exts='exts',
                           mkdir='mkdir'):
                pass

            mock_save_figure.assert_called_once_with(
                filename='filename',
                directory='directory',
                exts='exts',
                mkdir='mkdir',
                from_context_manager=True,
            )
项目:pip-update-requirements    作者:alanhamlett    | 项目源码 | 文件源码
def setUp(self):
        """Start a mock patcher for every target listed in ``self.patch_these``.

        Each entry of ``patch_these`` is either:

        * a dotted target string, or
        * a list/tuple ``(target[, return_value])``. When ``return_value`` is
          callable it is called and its result is used as the mock's return value.

        Started mocks are stored in ``self.patched`` keyed by target. Logging is
        disabled for the duration of the test.

        NOTE(review): the patchers are not stopped here; confirm tearDown calls
        ``mock.patch.stopall()`` or similar.
        """
        # disable logging while testing
        logging.disable(logging.CRITICAL)

        self.patched = {}
        for patch_this in getattr(self, 'patch_these', ()):
            if isinstance(patch_this, (list, tuple)):
                # Ordered sequence: first item is the target, an optional second
                # item is the return value. Sets are deliberately not accepted:
                # they cannot be indexed and have no defined order.
                namespace, extra = patch_this[0], patch_this[1:]
            else:
                namespace, extra = patch_this, ()

            patcher = mock.patch(namespace)
            mocked = patcher.start()
            mocked.reset_mock()
            self.patched[namespace] = mocked

            # Only configure a return value when one was actually supplied.
            if extra:
                retval = extra[0]
                if callable(retval):
                    retval = retval()
                mocked.return_value = retval
项目:xavier    作者:bepress    | 项目源码 | 文件源码
def test_publish_sns_event():
    """publish_sns_message() sends via SNS, and raises when SNS returns no response."""
    arn = 'arn:abc'
    message = "message"
    send_target = 'xavier.aws.sns.send_sns_message'

    # Happy path: SNS answers with a MessageId.
    with patch(send_target) as send:
        send.return_value = {"MessageId": "1234"}
        publish_sns_message(arn)(message)
        send.assert_called_once_with(TopicArn=arn, Message=message)

    # Failure path: a None response makes the publisher raise.
    with patch(send_target) as send:
        send.return_value = None
        publisher = publish_sns_message(arn)
        with pytest.raises(Exception):
            publisher(message)
        send.assert_called_once_with(TopicArn=arn, Message=message)
项目:jenkins-epo    作者:peopledoc    | 项目源码 | 文件源码
def test_yml_invalid(mocker, SETTINGS):
    """Unparsable YAML fetched from GitHub is recorded in ``current.errors``."""
    github = mocker.patch('jenkins_epo.extensions.core.GITHUB')
    from jenkins_epo.extensions.core import YamlExtension

    extension = YamlExtension('ext', Mock())
    extension.current = extension.bot.current
    extension.current.yaml = {}
    extension.current.errors = []

    github.fetch_file_contents = CoroutineMock(return_value="{INVALID")

    repository = extension.current.head.repository
    repository.url = 'https://github.com/owner/repo.git'
    repository.jobs = []

    yield from extension.run()

    assert github.fetch_file_contents.mock_calls
    assert extension.current.errors
    assert ext.current.errors
项目:jenkins-epo    作者:peopledoc    | 项目源码 | 文件源码
def test_yml_found(mocker, SETTINGS):
    """Jobs declared in the fetched YAML are loaded, except those excluded by ``jobs_filter``."""
    github = mocker.patch('jenkins_epo.extensions.core.GITHUB')
    job_cls = mocker.patch('jenkins_epo.extensions.core.Job')
    from jenkins_epo.extensions.core import YamlExtension

    job_cls.jobs_filter = ['*', '-skip']
    SETTINGS.update(YamlExtension.SETTINGS)

    extension = YamlExtension('ext', Mock())
    extension.current = extension.bot.current
    extension.current.yaml = {'job': dict()}

    github.fetch_file_contents = CoroutineMock(return_value="job: command\nskip: command")

    repository = extension.current.head.repository
    repository.url = 'https://github.com/owner/repo.git'
    repository.jobs = {}

    yield from extension.run()

    assert github.fetch_file_contents.mock_calls
    specs = extension.current.job_specs
    assert 'job' in specs
    assert 'skip' not in specs
项目:gitnet    作者:networks-lab    | 项目源码 | 文件源码
def test_default(self):
        """Does the default method print the proper information?"""
        with patch('sys.stdout', new=StringIO()) as fake_out:
            self.my_log.describe(mode="default")
            output = fake_out.getvalue()
            self.assertIn("Log containing 4 records from local git created at ", output)
            self.assertIn("\nOrigin:", output)
            self.assertNotIn("Filters:", output)
            self.assertIn("\nNumber of authors: 4\n", output)
            self.assertIn("\nNumber of files: 7\n", output)
            self.assertIn("\nMost common email address domains:", output)
            self.assertIn("\n\t @gmail.com [4 users]\n", output)
            self.assertIn("\nDate range: 2016-05-20 09:19:20-04:00 to 2016-05-26 11:21:03-04:00\n", output)
            self.assertIn("\nChange distribution summary:\n", output)
            self.assertIn("\n\t Files changed: Mean = 2.75, SD = 0.829\n", output)
            self.assertIn("\n\t Line insertions: Mean = 2.75, SD = 0.829\n", output)
            self.assertIn("\n\t Line deletions: Mean = nan, SD = nan\n", output)
            self.assertIn("\nNumber of merges: 0\n", output)
            self.assertIn("\nNumber of parsing errors: 0\n", output)
项目:gitnet    作者:networks-lab    | 项目源码 | 文件源码
def test_not_default(self):
        """ Does a non-default method print the proper information?
            Note: At this point, default is the only setting so they end up being the same."""
        with patch('sys.stdout', new=StringIO()) as fake_out:
            self.my_log.describe(mode="not default")
            output = fake_out.getvalue()
            self.assertIn("Log containing 4 records from local git created at ", output)
            self.assertIn("\nOrigin:", output)
            self.assertNotIn("Filters:", output)
            self.assertIn("\nNumber of authors: 4\n", output)
            self.assertIn("\nNumber of files: 7\n", output)
            self.assertIn("\nMost common email address domains:", output)
            self.assertIn("\n\t @gmail.com [4 users]\n", output)
            self.assertIn("\nDate range: 2016-05-20 09:19:20-04:00 to 2016-05-26 11:21:03-04:00\n", output)
            self.assertIn("\nChange distribution summary:\n", output)
            self.assertIn("\n\t Files changed: Mean = 2.75, SD = 0.829\n", output)
            self.assertIn("\n\t Line insertions: Mean = 2.75, SD = 0.829\n", output)
            self.assertIn("\n\t Line deletions: Mean = nan, SD = nan\n", output)
            self.assertIn("\nNumber of merges: 0\n", output)
            self.assertIn("\nNumber of parsing errors: 0\n", output)
项目:gitnet    作者:networks-lab    | 项目源码 | 文件源码
def test_whole(self):
        """Is the entire output as expected?"""
        with patch('sys.stdout', new=StringIO()) as fake_out:
            self.my_log.describe()
            out = fake_out.getvalue()

            self.assertRegex(out, "Log containing 4 records from local git created at ....-..-.. ..:..:..\.......\.\n"
                                  "Origin:  .*\n"
                                  "Number of authors: 4\n"
                                  "Number of files: 7\n"
                                  "Most common email address domains:\n"
                                  "\t @gmail.com \[4 users\]\n"
                                  "Date range: 2016-05-20 09:19:20-04:00 to 2016-05-26 11:21:03-04:00\n"
                                  "Change distribution summary:\n"
                                  "\t Files changed: Mean = 2.75, SD = 0.829\n"
                                  "\t Line insertions: Mean = 2.75, SD = 0.829\n"
                                  "\t Line deletions: Mean = nan, SD = nan\n"
                                  "Number of merges: 0\n"
                                  "Number of parsing errors: 0\n")
项目:gitnet    作者:networks-lab    | 项目源码 | 文件源码
def test_exclude(self):
        """Does exclude prevent statistics from being printed?"""
        with patch('sys.stdout', new=StringIO()) as fake_out:
            self.my_log.describe(exclude=['merges', 'errors', 'files', 'summary', 'changes', 'path', 'filters',
                                          'authors', 'dates', 'emails'])
            output = fake_out.getvalue()
            self.assertNotIn("Log containing 4 records from local git created at ", output)
            self.assertNotIn("\nOrigin:", output)
            self.assertNotIn("Filters:", output)
            self.assertNotIn("\nNumber of authors: 4\n", output)
            self.assertNotIn("\nNumber of files: 7\n", output)
            self.assertNotIn("\nMost common email address domains:", output)
            self.assertNotIn("\n\t @gmail.com [4 users]\n", output)
            self.assertNotIn("\nDate range: 2016-05-20 09:19:20-04:00 to 2016-05-26 11:21:03-04:00\n", output)
            self.assertNotIn("\nChange distribution summary:\n", output)
            self.assertNotIn("\n\t Files changed: Mean = 2.75, SD = 0.829\n", output)
            self.assertNotIn("\n\t Line insertions: Mean = 2.75, SD = 0.829\n", output)
            self.assertNotIn("\n\t Line deletions: Mean = nan, SD = nan\n", output)
            self.assertNotIn("\nNumber of merges: 0\n", output)
            self.assertNotIn("\nNumber of parsing errors: 0\n", output)
            self.assertEqual(output, "")
项目:zinc    作者:PressLabs    | 项目源码 | 文件源码
def test_update_record_values(api_client, zone, boto_client):
    """PATCHing a record's values is reflected in the API response and in Route 53."""
    # An additional, unrelated zone (presumably so the target zone is not the only one).
    G(m.Zone)

    changes = {'values': ['1.2.3.4']}
    record_url = '/zones/{}/records/{}'.format(zone.id, hash_test_record(zone))
    response = api_client.patch(record_url, data=changes)

    assert response.data == {**get_test_record(zone), **changes}

    aws_records = boto_client.list_resource_record_sets(HostedZoneId=zone.r53_zone.id)
    expected_aws = [record_data_to_aws({**get_test_record(zone), **changes}, zone.root)]
    assert aws_strip_ns_and_soa(aws_records, zone.root) == sorted(expected_aws, key=aws_sort_key)
项目:zinc    作者:PressLabs    | 项目源码 | 文件源码
def test_update_record_ttl(api_client, zone, boto_client):
    """PATCHing a record's TTL is reflected in the API response and in Route 53."""
    # An additional, unrelated zone (presumably so the target zone is not the only one).
    G(m.Zone)

    changes = {'ttl': 580}
    record_url = '/zones/{}/records/{}'.format(zone.id, hash_test_record(zone))
    response = api_client.patch(record_url, data=changes)

    assert response.data == {**get_test_record(zone), **changes}

    aws_records = boto_client.list_resource_record_sets(HostedZoneId=zone.r53_zone.id)
    expected_aws = [record_data_to_aws({**get_test_record(zone), **changes}, zone.root)]
    assert aws_strip_ns_and_soa(aws_records, zone.root) == sorted(expected_aws, key=aws_sort_key)
项目:sauna    作者:NicolasLM    | 项目源码 | 文件源码
def test_get_checks_as_dict(self):
        """get_checks_as_dict() renders each stored ServiceCheck as a plain dict."""
        shared = dict(timestamp=42, hostname='server1')
        stored_checks = {
            'foo': ServiceCheck(name='foo', status=0, output='foo out', **shared),
            'bar': ServiceCheck(name='bar', status=1, output='bar out', **shared),
        }
        expected = {
            'foo': {'status': 'OK', 'code': 0, 'timestamp': 42, 'output': 'foo out'},
            'bar': {'status': 'WARNING', 'code': 1, 'timestamp': 42, 'output': 'bar out'},
        }
        with mock.patch.dict('sauna.check_results', stored_checks):
            self.assertDictEqual(base.AsyncConsumer.get_checks_as_dict(), expected)
项目:helper_scripts    作者:pythonanywhere    | 项目源码 | 文件源码
def test_calls_all_stuff_in_right_order(self):
        """main() drives DjangoProject through every setup step in the expected order."""
        project_path = 'scripts.pa_start_django_webapp_with_virtualenv.DjangoProject'
        with patch(project_path) as project_cls:
            main('www.domain.com', 'django.version', 'python.version', nuke='nuke option')

        expected_steps = [
            call.sanity_checks(nuke='nuke option'),
            call.create_virtualenv('python.version', 'django.version', nuke='nuke option'),
            call.run_startproject(nuke='nuke option'),
            call.find_django_files(),
            call.update_settings_file(),
            call.run_collectstatic(),
            call.create_webapp(nuke='nuke option'),
            call.add_static_file_mappings(),
            call.update_wsgi_file(),
            call.webapp.reload(),
        ]
        assert project_cls.call_args == call('www.domain.com')
        assert project_cls.return_value.method_calls == expected_steps
项目:helper_scripts    作者:pythonanywhere    | 项目源码 | 文件源码
def test_actually_creates_django_project_in_virtualenv_with_hacked_settings_and_static_files(
        self, fake_home, virtualenvs_folder, api_token
    ):
        """End-to-end check of main().

        main() must:

        * install the requested Django into the virtualenv,
        * rewrite settings.py with MEDIA_ROOT and ALLOWED_HOSTS,
        * collect the admin static files.

        ``DjangoProject.update_wsgi_file`` and ``pythonanywhere.api.call_api`` are
        patched out so no WSGI file is touched and no API request is made.
        """

        with patch('scripts.pa_start_django_webapp_with_virtualenv.DjangoProject.update_wsgi_file'):
            with patch('pythonanywhere.api.call_api'):
                main('mydomain.com', '1.9.2', '2.7', nuke=False)

        # Ask the virtualenv's own interpreter which Django it has installed.
        # '-c' and the program text are passed as two separate argv items.
        django_version = subprocess.check_output([
            virtualenvs_folder / 'mydomain.com/bin/python',
            '-c',
            'import django; print(django.get_version())'
        ]).decode().strip()
        assert django_version == '1.9.2'

        with open(fake_home / 'mydomain.com/mysite/settings.py') as f:
            lines = f.read().split('\n')
        assert "MEDIA_ROOT = os.path.join(BASE_DIR, 'media')" in lines
        assert "ALLOWED_HOSTS = ['mydomain.com']" in lines

        assert 'base.css' in os.listdir(fake_home / 'mydomain.com/static/admin/css')
项目:mistletoe    作者:miyuchina    | 项目源码 | 文件源码
def test_interactive(self, mock_print, mock_markdown,
                         mock_print_heading, mock_import_readline):
        """interactive() collects typed lines until EOF, renders them, and prints an exit message."""
        def scripted_input(answers):
            # Replay `answers` in order, then raise EOFError once, then raise
            # KeyboardInterrupt on every further call.
            remaining = list(answers)
            state = {'eof_sent': False}

            def fake_input(prompt=''):
                if remaining:
                    return remaining.pop(0)
                if not state['eof_sent']:
                    state['eof_sent'] = True
                    raise EOFError
                raise KeyboardInterrupt
            return fake_input

        typed_lines = ['foo', 'bar', 'baz']
        with patch('builtins.input', scripted_input(typed_lines)):
            cli.interactive(sentinel.RendererCls)

        mock_import_readline.assert_called_with()
        mock_print_heading.assert_called_with(sentinel.RendererCls)
        mock_markdown.assert_called_with(['foo\n', 'bar\n', 'baz\n'], sentinel.RendererCls)
        mock_print.assert_has_calls([call('\nrendered text', end=''), call('\nExiting.')])
项目:django-webpack    作者:csinchok    | 项目源码 | 文件源码
def test_munge_config(self):
        """get_munged_config() fills WEBPACK_CONFIG from static-file lookups and settings."""
        def fake_find(path):
            return os.path.join('/some/path/static/', path)

        with mock.patch('webpack.conf.find', new=fake_find):
            munged = get_munged_config(WEBPACK_CONFIG)

        expected = WEBPACK_CONFIG_OUTPUT.format(url=settings.STATIC_URL, root=settings.STATIC_ROOT)
        self.assertEqual(munged, expected)
项目:release-script    作者:mitodl    | 项目源码 | 文件源码
def test_init_working_dir():
    """init_working_dir creates a git checkout in a temp dir and removes it on exit."""
    repo_url = "https://github.com/mitodl/release-script.git"
    access_token = 'fake_access_token'

    with patch('release.check_call', autospec=True) as check_call_mock:
        with init_working_dir(access_token, repo_url) as other_directory:
            assert os.path.exists(other_directory)
    assert not os.path.exists(other_directory)

    # First positional argument (the argv list) of every recorded check_call.
    issued_commands = [args[0] for args, _kwargs in check_call_mock.call_args_list]
    assert issued_commands == [
        ['git', 'init'],
        ['git', 'remote', 'add', 'origin', url_with_access_token(access_token, repo_url)],
        ['git', 'fetch'],
        ['git', 'checkout', '-t', 'origin/master'],
    ]
项目:antenna    作者:mozilla-services    | 项目源码 | 文件源码
def randommock():
    """Provide a context-manager factory that pins ``random.random()`` to a value.

    Usage::

        def test_something(randommock):
            with randommock(0.55):
                # test stuff...

    """
    @contextlib.contextmanager
    def _pinned_random(value):
        with mock.patch('random.random', return_value=value):
            yield

    return _pinned_random
项目:annotated-py-asyncio    作者:hhstore    | 项目源码 | 文件源码
def test_set_exc_handler_broken(self):
        """An exception raised inside the loop's exception handler is itself logged."""
        def divide_by_zero():
            1/0

        def broken_handler(loop, context):
            raise AttributeError('spam')

        self.loop._process_events = mock.Mock()
        self.loop.set_exception_handler(broken_handler)

        with mock.patch('asyncio.base_events.logger') as fake_logger:
            # Schedule the failing callback and run a single loop iteration.
            self.loop.call_soon(divide_by_zero)
            self.loop._run_once()
            fake_logger.error.assert_called_with(
                test_utils.MockPattern('Unhandled error in exception handler'),
                exc_info=(AttributeError, MOCK_ANY, MOCK_ANY))
项目:annotated-py-asyncio    作者:hhstore    | 项目源码 | 文件源码
def test_create_connection_timeout(self, m_socket):
        """The socket created by create_connection() is closed when connecting times out."""
        sock = mock.Mock()
        m_socket.socket.return_value = sock

        def fake_getaddrinfo(*args, **kw):
            # Resolve to a single IPv4/TCP address without any real lookup.
            resolved = asyncio.Future(loop=self.loop)
            resolved.set_result([
                (socket.AF_INET, socket.SOCK_STREAM, 0, '', ('127.0.0.1', 80)),
            ])
            return resolved

        self.loop.getaddrinfo = fake_getaddrinfo

        with mock.patch.object(self.loop, 'sock_connect', side_effect=asyncio.TimeoutError):
            connecting = self.loop.create_connection(MyProto, '127.0.0.1', 80)
            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(connecting)
            self.assertTrue(sock.close.called)
项目:annotated-py-asyncio    作者:hhstore    | 项目源码 | 文件源码
def setUp(self):
        """Create a test loop, a Protocol mock, and a fake pipe on fd 5 that looks like a FIFO."""
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.pipe = mock.Mock(spec_set=io.RawIOBase)
        self.pipe.fileno.return_value = 5

        # Turn _set_nonblocking into a no-op for the fake descriptor.
        nonblocking_patch = mock.patch('asyncio.unix_events._set_nonblocking')
        nonblocking_patch.start()
        self.addCleanup(nonblocking_patch.stop)

        # Make os.fstat report the fake descriptor as a FIFO.
        fstat_patch = mock.patch('os.fstat')
        fake_stat = mock.Mock(st_mode=stat.S_IFIFO)
        fstat_patch.start().return_value = fake_stat
        self.addCleanup(fstat_patch.stop)
项目:annotated-py-asyncio    作者:hhstore    | 项目源码 | 文件源码
def setUp(self):
        """Create a test loop, a BaseProtocol mock, and a fake pipe on fd 5 that looks like a socket."""
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
        self.pipe = mock.Mock(spec_set=io.RawIOBase)
        self.pipe.fileno.return_value = 5

        # Turn _set_nonblocking into a no-op for the fake descriptor.
        nonblocking_patch = mock.patch('asyncio.unix_events._set_nonblocking')
        nonblocking_patch.start()
        self.addCleanup(nonblocking_patch.stop)

        # Make os.fstat report the fake descriptor as a socket.
        fstat_patch = mock.patch('os.fstat')
        fake_stat = mock.Mock(st_mode=stat.S_IFSOCK)
        fstat_patch.start().return_value = fake_stat
        self.addCleanup(fstat_patch.stop)
项目:annotated-py-asyncio    作者:hhstore    | 项目源码 | 文件源码
def waitpid_mocks(func):
        """Decorate a test method so it runs with ``os.waitpid`` and the wait-status helpers spied on.

        Each of ``os.WTERMSIG``, ``os.WEXITSTATUS``, ``os.WIFSIGNALED``,
        ``os.WIFEXITED`` and ``os.waitpid`` is replaced by a ``Mock`` wrapping the
        same-named attribute of the test instance. The decorated method receives
        the mocks bundled in a ``WaitPidMocks`` as its second argument.
        """
        def wrapped_func(self):
            def spy(target, implementation):
                # A Mock that delegates to `implementation`, so calls are both
                # recorded and answered by the test's fake.
                return mock.patch(target, wraps=implementation,
                                  new_callable=mock.Mock)

            with spy('os.WTERMSIG', self.WTERMSIG) as termsig_mock, \
                    spy('os.WEXITSTATUS', self.WEXITSTATUS) as exitstatus_mock, \
                    spy('os.WIFSIGNALED', self.WIFSIGNALED) as ifsignaled_mock, \
                    spy('os.WIFEXITED', self.WIFEXITED) as ifexited_mock, \
                    spy('os.waitpid', self.waitpid) as waitpid_mock:
                bundle = WaitPidMocks(waitpid_mock, ifexited_mock, ifsignaled_mock,
                                      exitstatus_mock, termsig_mock)
                func(self, bundle)
        return wrapped_func
项目:triage    作者:dssg    | 项目源码 | 文件源码
def test_build_error(experiment_class):
    """An error raised while building matrices propagates out of the experiment run."""
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        ensure_db(engine)

        with TemporaryDirectory() as workdir:
            experiment = experiment_class(
                config=sample_config(),
                db_engine=engine,
                model_storage_class=FSModelStorageEngine,
                project_path=os.path.join(workdir, 'inspections'),
            )

            with mock.patch.object(experiment, 'build_matrices',
                                   side_effect=RuntimeError('boom!')):
                with pytest.raises(RuntimeError):
                    experiment()
项目:triage    作者:dssg    | 项目源码 | 文件源码
def test_build_error_cleanup_timeout(_clean_up_mock, experiment_class):
    """A cleanup timeout after a build error raises TimeoutError chained to the build error."""
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        ensure_db(engine)

        with TemporaryDirectory() as workdir:
            experiment = experiment_class(
                config=sample_config(),
                db_engine=engine,
                model_storage_class=FSModelStorageEngine,
                project_path=os.path.join(workdir, 'inspections'),
                cleanup_timeout=0.02,  # deliberately short so cleanup times out
            )

            with mock.patch.object(experiment, 'build_matrices',
                                   side_effect=RuntimeError('boom!')) as build_mock:
                with pytest.raises(TimeoutError) as exc_info:
                    experiment()

    # The TimeoutError is what surfaces, but the original build error stays
    # reachable through __context__ (and so appears in a standard traceback).
    assert exc_info.value.__context__ is build_mock.side_effect
项目:image-quantizer    作者:se7entyse7en    | 项目源码 | 文件源码
def test_compare(self):
        """compare() runs over a grid of quantization settings without showing a window."""
        quant = quantizer.ImageQuantizer()

        # Every method, each with 8 and then 16 colors.
        settings_grid = [
            {'n_colors': n_colors, 'method': method}
            for method in ('random', 'kmeans', 'random+lab', 'kmeans+lab')
            for n_colors in (8, 16)
        ]
        qimages = quant.quantize_multi(settings_grid,
                                       image_filename=self._get_image_path('Lenna.png'))

        # plt.show is stubbed so the comparison figure is never displayed.
        with mock.patch('image_quantizer.quantizer.plt.show', lambda: None):
            quantizer.compare(*qimages)
项目:valhalla    作者:LCOGT    | 项目源码 | 文件源码
def setUp(self):
        """Freeze the contention module's clock and create 24 pending requests.

        ``timezone`` and ``get_site_rise_set_intervals`` inside
        ``valhalla.userrequests.contention`` are patched. The patchers are stored on
        ``self`` so they can be stopped later (presumably in ``tearDown``, which is
        not visible here). Request ``i`` gets a window ending ``i`` hours after
        ``timezone.now()``. It also gets a random sidereal target, a molecule, a
        location and constraints.
        """
        super().setUp()

        self.now = datetime(year=2017, month=5, day=12, hour=10, tzinfo=timezone.utc)

        self.timezone_patch = patch('valhalla.userrequests.contention.timezone')
        self.mock_timezone = self.timezone_patch.start()
        self.mock_timezone.now.return_value = self.now

        self.site_intervals_patch = patch('valhalla.userrequests.contention.get_site_rise_set_intervals')
        self.mock_site_intervals = self.site_intervals_patch.start()

        for i in range(24):
            request = mixer.blend(Request, state='PENDING')
            mixer.blend(
                Window, start=timezone.now(), end=timezone.now() + timedelta(hours=i), request=request
            )
            # Right ascension spans 0..360 degrees, but declination is bounded to
            # -90..90 degrees; values outside that range are not real sky positions.
            mixer.blend(
                Target, ra=random.randint(0, 360), dec=random.randint(-90, 90),
                proper_motion_ra=0.0, proper_motion_dec=0.0, type='SIDEREAL', request=request
            )
            mixer.blend(Molecule, instrument_name='1M0-SCICAM-SBIG', request=request)
            mixer.blend(Location, request=request)
            mixer.blend(Constraints, request=request)
项目:socialhome    作者:jaywink    | 项目源码 | 文件源码
def test_hcard_responds_on_404_on_unknown_user(self, client):
        """The hCard view returns 404 for unknown GUIDs and when the lookup raises ValueError."""
        def hcard_status(guid):
            url = reverse("federate:hcard", kwargs={"guid": guid})
            return client.get(url).status_code

        assert hcard_status("fehwuyfehiufhewiuhfiuhuiewfew") == 404

        with patch("socialhome.federate.views.get_object_or_404") as fake_get:
            # A ValueError from the lookup must also end up as a 404.
            Profile.objects.filter(user__username="foobar").update(rsa_public_key="fooobar")
            fake_get.side_effect = ValueError()
            assert hcard_status("foobar") == 404
项目:socialhome    作者:jaywink    | 项目源码 | 文件源码
def test_pyembed_errors_swallowed(self):
        """fetch_oembed_preview() returns a falsy result when PyEmbed raises any of these errors."""
        swallowed = (PyEmbedError, PyEmbedDiscoveryError, PyEmbedConsumerError, ValueError)
        for exc_type in swallowed:
            with patch("socialhome.content.previews.PyEmbed.embed", side_effect=exc_type):
                self.assertFalse(fetch_oembed_preview(self.content, self.urls))
项目:flash_services    作者:textbook    | 项目源码 | 文件源码
def test_define_services(uuid4):
    """Only names present in SERVICES are instantiated, each keyed by a uuid4 hex."""
    bar_service = mock.MagicMock()
    configs = [{'name': 'bar'}, {'name': 'baz'}]

    with mock.patch.dict(SERVICES, {'bar': bar_service}, clear=True):
        result = define_services(configs)

    # 'baz' is unknown, so only one id was generated.
    uuid4.assert_called_once_with()
    expected_key = uuid4().hex
    assert result == {expected_key: bar_service.from_config.return_value}
项目:Easysentiment    作者:Jflick58    | 项目源码 | 文件源码
def test_cli(argv):
    """Running the CLI with sub-command ``argv`` calls its handler once with no arguments.

    ``argv`` is presumably supplied by a pytest parametrization that is not
    visible here. It must be one of the three sub-command names mapped below.
    """
    with mock.patch('easysentiment.cli.scrape') as scrape_stub, \
            mock.patch('easysentiment.cli.scrape_and_analyze') as scrape_and_analyze_stub, \
            mock.patch('easysentiment.cli.analyze_sentiment') as analyze_stub:
        from easysentiment.cli import cli
        cli([argv])
        expected_handler = {
            'scrape': scrape_stub,
            'scrape-and-analyze': scrape_and_analyze_stub,
            'analyze-sentiment': analyze_stub,
        }[argv]
        expected_handler.assert_called_once_with()
项目:django-heartbeat    作者:pbs    | 项目源码 | 文件源码
def test_build_version_with_valid_package_name(self):
    """``build.check`` reports name and version of the package named in HEARTBEAT.

    ``build.WorkingSet.find`` is patched to return a stub package, so no real
    distribution lookup happens. The global ``settings.HEARTBEAT`` value is
    restored afterwards so this test does not leak configuration into others.
    """
    package = Mock(project_name='foo', version='1.0.0')
    missing = object()  # sentinel: HEARTBEAT was not set before this test
    previous = getattr(settings, 'HEARTBEAT', missing)
    setattr(settings, 'HEARTBEAT', {'package_name': 'foo'})
    try:
        with mock.patch.object(build.WorkingSet, 'find', return_value=package):
            distro = build.check(request=None)
            assert distro == {'name': 'foo', 'version': '1.0.0'}
    finally:
        # Explicit save/restore instead of patch.object(create=True): on a
        # proxy settings object the latter can delete a pre-existing value.
        if previous is missing:
            delattr(settings, 'HEARTBEAT')
        else:
            setattr(settings, 'HEARTBEAT', previous)
项目:tts-bug-bounty-dashboard    作者:18F    | 项目源码 | 文件源码
def call_h1sync(*args, reports=None):
    """Run the ``h1sync`` management command with ``dashboard.h1.find_reports`` stubbed.

    ``args`` are forwarded to the command. ``reports`` is what the stub returns
    (an empty list when omitted). Returns ``(captured_stdout, find_reports_mock)``
    so callers can check both the command output and how the stub was called.
    """
    stub_reports = [] if reports is None else reports
    captured = io.StringIO()
    with mock.patch('dashboard.h1.find_reports', return_value=stub_reports) as find_reports_stub:
        call_command('h1sync', *args, stdout=captured)
    return captured.getvalue(), find_reports_stub
项目:foremast    作者:gogoair    | 项目源码 | 文件源码
def test_default_security_groups(mock_properties, mock_details):
    """Make sure default Security Groups are added to the ingress rules."""
    # mock_properties / mock_details are presumably injected by patch
    # decorators that are not visible here; only mock_properties is configured.
    app_ingress = {'test_app': [{'start_port': 30, 'end_port': 30}]}
    mock_properties.return_value = {
        'security_group': {'ingress': app_ingress, 'description': ''},
    }

    default_rules = {'myapp': [{'start_port': '22', 'end_port': '22', 'protocol': 'tcp'}]}
    defaults_target = 'foremast.securitygroup.create_securitygroup.DEFAULT_SECURITYGROUP_RULES'
    with mock.patch.dict(defaults_target, default_rules):
        updated_rules = SpinnakerSecurityGroup().update_default_rules()
        assert 'myapp' in updated_rules
项目:cxflow-tensorflow    作者:Cognexa    | 项目源码 | 文件源码
def test_write(self):
    """Test if ``WriteTensorBoard`` writes to its FileWriter.

    While ``FileWriter.add_summary`` is mocked, each ``after_epoch`` call must
    add exactly one summary. The remaining calls run against the real writer
    with dict-valued variables that carry a ``mean`` / ``nanmean`` key.
    """
    hook = WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model())
    try:
        with mock.patch.object(tf.summary.FileWriter, 'add_summary') as mocked_add_summary:
            hook.after_epoch(42, {})
            self.assertEqual(mocked_add_summary.call_count, 1)
            hook.after_epoch(43, {'valid': {'accuracy': 1.0}})
            self.assertEqual(mocked_add_summary.call_count, 2)
        hook.after_epoch(44, {'valid': {'accuracy': {'mean': np.float32(1.0)}}})
        hook.after_epoch(45, {'valid': {'accuracy': {'nanmean': 1.0}}})
    finally:
        # Close even when an assertion fails so the writer is not leaked.
        hook._summary_writer.close()
项目:cxflow-tensorflow    作者:Cognexa    | 项目源码 | 文件源码
def test_image_variable(self):
    """Test if ``WriteTensorBoard`` checks the image variables properly.

    With ``'plot'`` declared as an image variable, a list and a 1-D array are
    rejected with ``AssertionError``, while a ``(10, 10, 3)`` array is accepted.
    """
    hook = WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model(), image_variables=['plot'])
    try:
        # cv2 is replaced in sys.modules by ``cv2_mock`` (defined elsewhere),
        # presumably so the test does not depend on OpenCV being installed.
        with mock.patch.dict('sys.modules', cv2=cv2_mock):
            with self.assertRaises(AssertionError):
                hook.after_epoch(0, {'train': {'plot': [None]}})

            with self.assertRaises(AssertionError):
                hook.after_epoch(1, {'train': {'plot': np.zeros((10,))}})

            hook.after_epoch(2, {'train': {'plot': np.zeros((10, 10, 3))}})
    finally:
        # Close even when an assertion fails so the writer is not leaked.
        hook._summary_writer.close()
项目:cxflow-tensorflow    作者:Cognexa    | 项目源码 | 文件源码
def test_unknown_type(self):
    """Test if ``WriteTensorBoard`` handles unknown variable types as expected.

    Covers the default policy (nothing logged), ``on_unknown_type='warn'``
    (one WARNING record), ``on_unknown_type='error'`` (``ValueError``), and
    the case where the offending variable is listed in ``image_variables``.
    In that last case the ``'error'`` policy does not raise.
    """
    bad_epoch_data = {'valid': {'accuracy': 'bad_type'}}
    # Every hook created below; all of their writers are closed in ``finally``.
    # The original test closed only the last hook's writer and leaked the rest.
    hooks = []

    def make_hook(**kwargs):
        """Create a hook on the shared tmpdir/model and register it for cleanup."""
        created = WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model(), **kwargs)
        hooks.append(created)
        return created

    try:
        # test ignore
        hook = make_hook()
        with LogCapture(level=logging.INFO) as log_capture:
            hook.after_epoch(42, bad_epoch_data)
        log_capture.check()

        # test warn
        warn_hook = make_hook(on_unknown_type='warn')
        with LogCapture(level=logging.INFO) as log_capture2:
            warn_hook.after_epoch(42, bad_epoch_data)
        log_capture2.check(('root', 'WARNING', 'Variable `accuracy` in stream `valid` has to be of type `int` '
                                               'or `float` (or a `dict` with a key named `mean` or `nanmean` '
                                               'whose corresponding value is of type `int` or `float`), '
                                               'found `<class \'str\'>` instead.'))

        # test error
        raise_hook = make_hook(on_unknown_type='error')
        with self.assertRaises(ValueError):
            raise_hook.after_epoch(42, bad_epoch_data)

        with mock.patch.dict('sys.modules', cv2=cv2_mock):
            # test skip image variables
            skip_hook = make_hook(on_unknown_type='error', image_variables=['accuracy'])
            skip_hook.after_epoch(42, {'valid': {'accuracy': np.zeros((10, 10, 3))}})
    finally:
        for created in hooks:
            created._summary_writer.close()
项目:bob    作者:BobBuildTool    | 项目源码 | 文件源码
def testBigIno(self):
        """Test that index handles big inode numbers as found on Windows"""

        s = MagicMock()
        s.st_mode=33188
        s.st_ino=-5345198597064824875
        s.st_dev=65027
        s.st_nlink=1
        s.st_uid=1000
        s.st_gid=1000
        s.st_size=3
        s.st_atime=1452798827
        s.st_mtime=1452798827
        s.st_ctime=1452798827
        mock_lstat = MagicMock()
        mock_lstat.return_value = s

        with NamedTemporaryFile() as index:
            with TemporaryDirectory() as tmp:
                with open(os.path.join(tmp, "ghost"), 'wb') as f:
                    f.write(b'abc')

                with patch('os.lstat', mock_lstat):
                    hashDirectory(tmp, index.name)

                with open(index.name, "rb") as f:
                    assert f.read(4) == b'BOB1'