Python toolz module: merge() usage examples

The following 50 code examples, extracted from open-source Python projects, illustrate how to use toolz.merge().
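
For reference, toolz.merge builds a new dict out of the dicts it is given, with later arguments winning on key collisions; it also accepts a single iterable of dicts. A minimal sketch of that behavior, with illustrative values:

from toolz import merge

defaults = {'host': 'localhost', 'port': 8123}
overrides = {'port': 9000}

merge(defaults, overrides)    # {'host': 'localhost', 'port': 9000}
merge([defaults, overrides])  # a single iterable of dicts gives the same result
# the inputs are left untouched: merge always returns a new dict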

Project: dl4mt-multi | Author: nyu-dl
def print_training_params(self, cgs, training_params):

        enc_dec_param_dict = merge(self.encoder.get_params(),
                                   self.decoder.get_params())

        # Print which parameters are excluded
        for k, v in cgs.iteritems():
            excluded_all = list(set(v.parameters) - set(training_params[k]))
            for p in excluded_all:
                logger.info(
                    'Excluding from training of CG[{}]: {}'
                    .format(k, [key for key, val in
                                enc_dec_param_dict.iteritems()
                                if val == p][0]))
            logger.info(
                'Total number of excluded parameters for CG[{}]: [{}]'
                .format(k, len(excluded_all)))

        for k, v in training_params.iteritems():
            for p in v:
                logger.info('Training parameter from CG[{}]: {}'
                            .format(k, p.name))
            logger.info(
                'Total number of parameters to be trained for CG[{}]: [{}]'
                .format(k, len(v)))
Project: adt | Author: llllllllll
def scrutinize(self, scrutine, context_frame):
        constructor = type(scrutine)
        if constructor.__name__ != self._constructor_name:
            raise NoMatch()

        kwargs = scrutine._kwargs
        # the context to evaluate the thunk in
        context = {
            Call(Normal(name_lookup), (Normal(name),), {}): Normal(value)
            for name, value in merge(
                vars(builtins),
                context_frame.f_globals,
                context_frame.f_locals,
                # the newly bound arguments have the highest precedence
                dict(zip(self._argnames, scrutine._args)),
                {v: kwargs[k] for k, v in self._kwargnames.items()},
            ).items()
        }
        bound_tree = LTree.parse(self._expr).subs(context)
        return strict(bound_tree.lcompile())
Project: catalyst | Author: enigmampc
def run_example(example_name, environ):
    """
    Run an example module from catalyst.examples.
    """
    mod = EXAMPLE_MODULES[example_name]

    register_calendar("YAHOO", get_calendar("NYSE"), force=True)

    return run_algorithm(
        initialize=getattr(mod, 'initialize', None),
        handle_data=getattr(mod, 'handle_data', None),
        before_trading_start=getattr(mod, 'before_trading_start', None),
        analyze=getattr(mod, 'analyze', None),
        bundle='test',
        environ=environ,
        # Provide a default capital base, but allow the test to override.
        **merge({'capital_base': 1e7}, mod._test_args())
    )
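
The merge call above implements a defaults-plus-overrides pattern: later dicts win, so any 'capital_base' supplied by mod._test_args() replaces the default. A hedged sketch of the idiom, with a hypothetical stand-in for the real entry point:

from toolz import merge

def run_algorithm(**kwargs):  # hypothetical stand-in for catalyst's run_algorithm
    return kwargs

test_args = {'capital_base': 5e6}
run_algorithm(**merge({'capital_base': 1e7}, test_args))
# -> {'capital_base': 5000000.0}: the per-test value overrides the default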
Project: pandahouse | Author: kszucs
def prepare(query, connection=None, external=None):
    connection = merge(_default, connection or {})
    database = escape(connection['database'])
    query = query.format(db=database)
    params = {'query': query,
              'user': connection['user'],
              'password': connection['password']}
    params = valfilter(lambda x: x, params)

    files = {}
    external = external or {}
    for name, (structure, serialized) in external.items():
        params['{}_format'.format(name)] = 'CSV'
        params['{}_structure'.format(name)] = structure
        files[name] = serialized

    host = connection['host']

    return host, params, files
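
Two toolz helpers do the work here: merge layers the caller's connection settings over the module-level defaults, and valfilter drops falsy values (such as an empty password) from the request parameters. A small sketch with made-up defaults:

from toolz import merge, valfilter

_default = {'host': 'http://localhost:8123', 'user': 'default', 'password': ''}

connection = merge(_default, {'user': 'reader'})
params = valfilter(lambda x: x, {'user': connection['user'],
                                 'password': connection['password']})
# params == {'user': 'reader'}: the empty password is filtered out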
Project: sgnmt | Author: ucam-smt
def apply(self, source_sentence, source_sentence_mask):
        """Creates bidirectional RNN source annotations.

        Args:
            source_sentence (Variable): Source sentence with words in
                                        vector representation.
            source_sentence_mask (Variable): Source mask

        Returns:
            Variable: the source annotations.
        """
        # Time as first dimension
        source_sentence = source_sentence.T
        source_sentence_mask = source_sentence_mask.T

        representation = self.bidir.apply(
            merge(self.fwd_fork.apply(source_sentence, as_dict=True),
                  {'mask': source_sentence_mask}),
            merge(self.back_fork.apply(source_sentence, as_dict=True),
                  {'mask': source_sentence_mask})
        )
        return representation, source_sentence_mask
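
A recurring idiom in these Blocks-based examples: merge the dict returned by fork.apply(..., as_dict=True) with a {'mask': ...} entry, then unpack the result as keyword arguments. In outline, with placeholder values:

from toolz import merge

def apply_fn(**kwargs):  # placeholder for a brick's apply method
    return sorted(kwargs)

fork_outputs = {'inputs': 'W_x', 'gate_inputs': 'W_g'}
apply_fn(**merge(fork_outputs, {'mask': 'source_mask'}))
# -> ['gate_inputs', 'inputs', 'mask']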
Project: dl4mt-multi-src | Author: nyu-dl
def print_training_params(self, cgs, training_params):

        enc_dec_param_dict = merge(self.encoder.get_params(),
                                   self.decoder.get_params())

        # Print which parameters are excluded
        for k, v in cgs.iteritems():
            excluded_all = list(set(v.parameters) - set(training_params[k]))
            for p in excluded_all:
                logger.info(
                    'Excluding from training of CG[{}]: {}'
                    .format(k, [key for key, val in
                                enc_dec_param_dict.iteritems()
                                if val == p][0]))
            logger.info(
                'Total number of excluded parameters for CG[{}]: [{}]'
                .format(k, len(excluded_all)))

        for k, v in training_params.iteritems():
            for p in v:
                logger.info('Training parameter from CG[{}]: {}'
                            .format(k, p.name))
            logger.info(
                'Total number of parameters to be trained for CG[{}]: [{}]'
                .format(k, len(v)))
Project: zipline-chinese | Author: zhanghan1990
def __init__(self, key, mapping=None, **kwargs):
        self._map = {}
        self._sorted_key_names = []
        self._sort_key = key

        self.update(merge(mapping or {}, kwargs))
Project: zipline-chinese | Author: zhanghan1990
def load_adjusted_array(self, columns, dates, assets, mask):
        return merge(
            self.get_loader(column).load_adjusted_array(
                [column], dates, assets, mask
            )
            for column in columns
        )
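
Note that merge is called here with a single generator of dicts rather than several dict arguments; toolz accepts either form. A minimal sketch:

from toolz import merge

dicts = ({'col%d' % i: i} for i in range(3))  # a generator of one-key dicts
merge(dicts)  # {'col0': 0, 'col1': 1, 'col2': 2}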
Project: zipline-chinese | Author: zhanghan1990
def test_ewm_stats(self, window_length):

        def ewma_name(decay_rate):
            return 'ewma_%s' % decay_rate

        def ewmstd_name(decay_rate):
            return 'ewmstd_%s' % decay_rate

        decay_rates = [0.25, 0.5, 0.75]
        ewmas = {
            ewma_name(decay_rate): EWMA(
                inputs=(USEquityPricing.close,),
                window_length=window_length,
                decay_rate=decay_rate,
            )
            for decay_rate in decay_rates
        }

        ewmstds = {
            ewmstd_name(decay_rate): EWMSTD(
                inputs=(USEquityPricing.close,),
                window_length=window_length,
                decay_rate=decay_rate,
            )
            for decay_rate in decay_rates
        }

        all_results = self.engine.run_pipeline(
            Pipeline(columns=merge(ewmas, ewmstds)),
            self.dates[window_length],
            self.dates[-1],
        )

        for decay_rate in decay_rates:
            ewma_result = all_results[ewma_name(decay_rate)].unstack()
            ewma_expected = self.expected_ewma(window_length, decay_rate)
            assert_frame_equal(ewma_result, ewma_expected)

            ewmstd_result = all_results[ewmstd_name(decay_rate)].unstack()
            ewmstd_expected = self.expected_ewmstd(window_length, decay_rate)
            assert_frame_equal(ewmstd_result, ewmstd_expected)
Project: dl4mt-multi | Author: nyu-dl
def do(self, which_callback, *args):
        iterations_done = self.main_loop.status['iterations_done']
        if self.burnin <= iterations_done:
            # Save the model here
            iterations_done = self.main_loop.status['iterations_done']
            filename = os.path.join(
                self.saveto, 'params_iter{}.npz'.format(iterations_done))
            s = signal.signal(signal.SIGINT, signal.SIG_IGN)
            logger.info(" Incremental dump {}".format(filename))
            params_to_save = []
            for cg_name in self.main_loop.models.keys():
                params_to_save.append(
                    self.main_loop.models[cg_name].get_param_values())
            params_to_save = merge(params_to_save)
            secure_numpy_save(params_to_save, filename)
            if self.save_iter_state:
                filename_is = os.path.join(
                    self.saveto,
                    'iterations_state_iter{}.pkl'.format(iterations_done))
                logger.info(" Incremental dump {}".format(filename_is))
                secure_pickle_dump(self.main_loop.iteration_state, filename_is)
            if self.save_log:
                filename_log = os.path.join(
                    self.saveto,
                    'log_iter{}'.format(iterations_done))
                logger.info(" Incremental dump {}".format(filename_log))
                secure_pickle_dump(self.main_loop.log, filename_log)
            signal.signal(signal.SIGINT, s)
Project: dl4mt-multi | Author: nyu-dl
def dump_parameters(self, main_loop):
        params_to_save = []
        for model in main_loop.models.values():
            params_to_save.append(model.get_param_values())
        secure_numpy_save(merge(params_to_save),
                          self.path_to_parameters)
Project: dl4mt-multi | Author: nyu-dl
def get_params(self):
        return merge(self.encoder.get_params(),
                     self.decoder.get_params())
Project: dl4mt-multi | Author: nyu-dl
def _save_model(self, bleu_score):
        if self._is_valid_to_save(bleu_score):
            model = ModelInfo(
                bleu_score, self.saveto, self.enc_id, self.dec_id)

            # Manage n-best model list first
            if len(self.best_models) >= self.track_n_models:
                old_model = self.best_models[0]
                if old_model.path and os.path.isfile(old_model.path):
                    logger.info("Deleting old model %s" % old_model.path)
                    os.remove(old_model.path)
                self.best_models.remove(old_model)

            self.best_models.append(model)
            self.best_models.sort(key=operator.attrgetter('bleu_score'))

            # Save the model here
            s = signal.signal(signal.SIGINT, signal.SIG_IGN)
            logger.info("Saving new model {}".format(model.path))
            params_to_save = []
            for cg_name in self.main_loop.models.keys():
                params_to_save.append(
                    self.main_loop.models[cg_name].get_param_values())
            params_to_save = merge(params_to_save)

            self._save_params(model, params_to_save)
            self._save_bleu_scores()

            signal.signal(signal.SIGINT, s)
Project: dask_gdf | Author: gpuopenanalytics
def index(self):
        """Return dask Index instance"""
        name = self._name + '-index'
        dsk = {(name, i): (getattr, key, 'index')
               for i, key in enumerate(self._keys())}
        return Index(merge(dsk, self.dask), name,
                     self._meta.index, self.divisions)
Project: dask_gdf | Author: gpuopenanalytics
def head(self, n=5, npartitions=1, compute=True):
        """ First n rows of the dataset

        Parameters
        ----------
        n : int, optional
            The number of rows to return. Default is 5.
        npartitions : int, optional
            Elements are only taken from the first ``npartitions``, with a
            default of 1. If there are fewer than ``n`` rows in the first
            ``npartitions`` a warning will be raised and any found rows
            returned. Pass -1 to use all partitions.
        compute : bool, optional
            Whether to compute the result, default is True.
        """
        if npartitions <= -1:
            npartitions = self.npartitions
        if npartitions > self.npartitions:
            raise ValueError("only %d partitions, received "
                             "%d" % (self.npartitions, npartitions))

        name = 'head-%d-%d-%s' % (npartitions, n, self._name)

        if npartitions > 1:
            name_p = 'head-partial-%d-%s' % (n, self._name)
            dsk = {(name_p, i): (M.head, (self._name, i), n)
                   for i in range(npartitions)}
            dsk[(name, 0)] = (M.head, (gd.concat, sorted(dsk)), n)
        else:
            dsk = {(name, 0): (M.head, (self._name, 0), n)}

        res = new_dd_object(merge(self.dask, dsk), name, self._meta,
                            (self.divisions[0], self.divisions[npartitions]))

        return res.compute() if compute else res
Project: dask_gdf | Author: gpuopenanalytics
def concat(objs):
    """Concantenate dask gdf objects

    Parameters
    ----------

    objs : sequence of DataFrame, Series, Index
        A sequence of objects to be concatenated.
    """
    objs = [_daskify(x) for x in objs]
    meta = gd.concat(_extract_meta(objs))

    name = "concat-" + uuid4().hex
    dsk = {}
    divisions = [0]
    base = 0
    lastdiv = 0
    for obj in objs:
        for k, i in obj._keys():
            dsk[name, base + i] = k, i
        base += obj.npartitions
        divisions.extend([d + lastdiv for d in obj.divisions[1:]])
        lastdiv = obj.divisions[-1]

    dasks = [o.dask for o in objs]
    dsk = merge(dsk, *dasks)
    return new_dd_object(dsk, name, meta, divisions)
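
merge(dsk, *dasks) is how dask-style code stitches task graphs together: each graph is a plain dict keyed by (name, partition-index) tuples, so merging the dicts merges the graphs. A rough sketch with toy graphs:

from toolz import merge

graph_a = {('a', 0): 1, ('a', 1): 2}
graph_b = {('b', 0): (sum, [('a', 0), ('a', 1)])}
full = merge(graph_a, *[graph_b])  # same shape as merge(dsk, *dasks) above
# full now maps all three (name, index) keys to their tasks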
Project: dask_gdf | Author: gpuopenanalytics
def __getitem__(self, key):
        if isinstance(key, str) and key in self.columns:
            meta = self._meta[key]
            name = 'getitem-%s' % tokenize(self, key)
            dsk = {(name, i): (operator.getitem, (self._name, i), key)
                   for i in range(self.npartitions)}
            return Series(merge(self.dask, dsk), name, meta, self.divisions)

        raise NotImplementedError("Indexing with %r" % key)
Project: mentos | Author: daskos
def __init__(self, scheduler, name, user=getpass.getuser(),
                 master=os.getenv('MESOS_MASTER', 'zk://localhost:2181'),
                 failover_timeout=100, capabilities=None, principal=None, secret=None,
                 implicit_acknowledgements=True, handlers={}, loop=None):
        self.loop = loop or IOLoop()
        self.master = master
        self.leading_master_seq = None
        self.leading_master_info = None

        self.scheduler = scheduler
        self.framework = {
            'user': user,
            'name': name,
            'capabilities': capabilities or [],
            'failover_timeout': failover_timeout,
            'hostname': socket.gethostname()
        }

        self.implicit_acknowledgements = implicit_acknowledgements

        defaults = {Event.SUBSCRIBED: self.on_subscribed,
                    Event.OFFERS: self.on_offers,
                    Event.RESCIND: self.on_rescind,
                    Event.UPDATE: self.on_update,
                    Event.MESSAGE: self.on_message,
                    Event.RESCIND_INVERSE_OFFER: self.on_rescind_inverse,
                    Event.FAILURE: self.on_failure,
                    Event.ERROR: self.on_error,
                    Event.HEARTBEAT: self.on_heartbeat,
                    Event.OUTBOUND_SUCCESS: self.on_outbound_success,
                    Event.OUTBOUND_ERROR: self.on_outbound_error}
        self.handlers = merge(defaults, handlers)

        self.subscription = Subscription(self.framework, self.master,
                                         '/api/v1/scheduler', self.handlers,
                                         principal=principal,
                                         secret=secret,
                                         timeout=failover_timeout,
                                         loop=self.loop)
Project: WindAdapter | Author: iLampard
def factor_load(start_date, end_date, factor_name, save_file=None, **kwargs):
    """
    :param start_date: str, start date of the factor data to load
    :param end_date: str, end date of the factor data to load
    :param factor_name: str, name of the factor to load
    :param save_file: str, optional, file to save the result to, e.g. '*.csv' or '*.pkl'
    :param kwargs: dict, optional

            freq: str, optional, data frequency, one of 'M', 'W', 'S', 'Y'; see enums.py - FreqType
            tenor: str, optional, data tenor, given as an integer plus a FreqType, e.g. '3M'
            sec_id, str/list, optional, security code(s) to load
            output_data_format: enum, optional, see enums.py - OutputFormat
                                MULTI_INDEX_DF: multi-index DataFrame, index=[date, secID], value = factor
                                PIVOT_TABLE_DF: DataFrame, index=date, columns = secID
            is_index: bool, optional, True: treat sec_id as an index and load the factor data
                                      of its constituent securities;
                                      False: load the factor data of sec_id directly
            date_format: str, optional, date format, default '%Y-%m-%d'
    :return: pd.DataFrame of factor data
    """
    if isinstance(factor_name, list):
        kwargs = merge(kwargs, {'output_data_format': OutputFormat.MULTI_INDEX_DF})
        factor_names = factor_name
    else:
        factor_names = [factor_name]

    ret = pd.DataFrame()
    for factor_name in factor_names:
        LOGGER.info('Loading factor data {0}'.format(factor_name))
        factor_loader = FactorLoader(start_date=start_date,
                                     end_date=end_date,
                                     factor_name=factor_name,
                                     **kwargs)
        factor_data = factor_loader.load_data()
        LOGGER.info('factor data {0} is loaded '.format(factor_name))
        ret = pd.concat([ret, factor_data], axis=1)
    if kwargs.get('reset_col_names'):
        ret.columns = factor_names
    if save_file:
        save_data_to_file(ret, save_file)
        LOGGER.critical('Data saved in {0}'.format(save_file))
    return ret
Project: catalyst | Author: enigmampc
def __init__(self, volume_limit, eta=ROOT_SYMBOL_TO_ETA):
        super(VolatilityVolumeShare, self).__init__()
        self.volume_limit = volume_limit

        # If 'eta' is a constant, use a dummy mapping to treat it as a
        # dictionary that always returns the same value.
        # NOTE: This dictionary does not handle unknown root symbols, so it may
        # be worth revisiting this behavior.
        if isinstance(eta, (int, float)):
            self._eta = DummyMapping(float(eta))
        else:
            # Eta is a dictionary. If the user's dictionary does not provide a
            # value for a certain contract, fall back on the pre-defined eta
            # values per root symbol.
            self._eta = merge(ROOT_SYMBOL_TO_ETA, eta)
Project: catalyst | Author: enigmampc
def __init__(self,
                 cost,
                 exchange_fee,
                 min_trade_cost=DEFAULT_MINIMUM_COST_PER_FUTURE_TRADE):
        # If 'cost' or 'exchange fee' are constants, use a dummy mapping to
        # treat them as a dictionary that always returns the same value.
        # NOTE: These dictionaries do not handle unknown root symbols, so it
        # may be worth revisiting this behavior.
        if isinstance(cost, (int, float)):
            self._cost_per_contract = DummyMapping(float(cost))
        else:
            # Cost per contract is a dictionary. If the user's dictionary does
            # not provide a commission cost for a certain contract, fall back
            # on the pre-defined cost values per root symbol.
            self._cost_per_contract = defaultdict(
                lambda: DEFAULT_PER_CONTRACT_COST, **cost
            )

        if isinstance(exchange_fee, (int, float)):
            self._exchange_fee = DummyMapping(float(exchange_fee))
        else:
            # Exchange fee is a dictionary. If the user's dictionary does not
            # provide an exchange fee for a certain contract, fall back on the
            # pre-defined exchange fees per root symbol.
            self._exchange_fee = merge(
                FUTURE_EXCHANGE_FEES_BY_SYMBOL, exchange_fee,
            )

        self.min_trade_cost = min_trade_cost or 0
Project: catalyst | Author: enigmampc
def load_adjusted_array(self, columns, dates, sids, mask):
        n, p = self.split_next_and_previous_event_columns(columns)
        return merge(
            self.load_next_events(n, dates, sids, mask),
            self.load_previous_events(p, dates, sids, mask),
        )
Project: catalyst | Author: enigmampc
def test_ewm_stats(self, window_length):

        def ewma_name(decay_rate):
            return 'ewma_%s' % decay_rate

        def ewmstd_name(decay_rate):
            return 'ewmstd_%s' % decay_rate

        decay_rates = [0.25, 0.5, 0.75]
        ewmas = {
            ewma_name(decay_rate): EWMA(
                inputs=(USEquityPricing.close,),
                window_length=window_length,
                decay_rate=decay_rate,
            )
            for decay_rate in decay_rates
        }

        ewmstds = {
            ewmstd_name(decay_rate): EWMSTD(
                inputs=(USEquityPricing.close,),
                window_length=window_length,
                decay_rate=decay_rate,
            )
            for decay_rate in decay_rates
        }

        all_results = self.engine.run_pipeline(
            Pipeline(columns=merge(ewmas, ewmstds)),
            self.dates[window_length],
            self.dates[-1],
        )

        for decay_rate in decay_rates:
            ewma_result = all_results[ewma_name(decay_rate)].unstack()
            ewma_expected = self.expected_ewma(window_length, decay_rate)
            assert_frame_equal(ewma_result, ewma_expected)

            ewmstd_result = all_results[ewmstd_name(decay_rate)].unstack()
            ewmstd_expected = self.expected_ewmstd(window_length, decay_rate)
            assert_frame_equal(ewmstd_result, ewmstd_expected)
Project: ribosome | Author: tek
def json_repr(self):
        return merge(self.to_map(), dict(__type__=self.__path__))
Project: ribosome | Author: tek
def ribosome_file_logging(name: str, file_kw: dict=dict()) -> None:
    prefix_path = options.nvim_log_file.value | (lambda: amino.logging.log_dir() / 'nvim')
    level = (
        DDEBUG
        if options.development and options.spec else
        options.file_log_level.value | logging.DEBUG
    )
    logfile = Path(f'{prefix_path}_ribo_{name}_{os.getpid()}')
    kw = merge(
        file_kw,
        dict(level=level, logfile=logfile)
    )
    return amino_root_file_logging(**kw)
Project: calculette-impots-python | Author: openfisca
def main():
    global parser
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--output-html', help='Output result page HTML to file')
    parser.add_argument('--saisies', dest='saisie_variables', metavar='nom=valeur', nargs='+', help='Variables saisies')
    parser.add_argument('--year', default='2015', type=int,
                        help='Calculer les impôts de l\'année N sur les revenus de l\'année N-1')
    args = parser.parse_args()

    cgi_url = 'http://www3.finances.gouv.fr/cgi-bin/calc-{}.cgi'.format(args.year)
    headers = {'User-Agent': 'Calculette-Impots-Python'}
    saisie_variables = {} if args.saisie_variables is None else dict(iter_saisie_variables(args.saisie_variables))
    default_saisie_variables = {
        # '0DA': '1965',
        # '1AJ': '15000',
        'pre_situation_famille': 'C',
        'pre_situation_residence': 'M',
        # 'simplifie': '1',
        }
    data = merge(default_saisie_variables, saisie_variables)
    response = requests.post(cgi_url, headers=headers, data=data)
    if args.output_html is not None:
        with open(args.output_html, 'w') as output_html_file:
            output_html_file.write(re.sub(
                pattern=r'=(.)/calcul_impot/2015/',
                repl=r'=\1http://www3.finances.gouv.fr/calcul_impot/2015/',
                string=response.text,
                ))
    root_node = etree.fromstring(response.text, etree.HTMLParser())
    results = list(iter_results(root_node))
    print(json.dumps(results, ensure_ascii=False, indent=2, sort_keys=True))

    return 0
Project: nekpy | Author: NekBox
def run_all(values, base, get=get_proc, num_workers = 4):
    full_dask = toolz.merge(val.dask for val in values)
    full_keys = [val._key for val in values]

    cache = {}
    if exists("{}.cache".format(base["prefix"])):
        with open("{}.cache".format(base["prefix"]), "r") as f:
            cache = json.load(f)

    full_dask.update(cache)

    with ProgressBar(), NekCallback(base) as rprof:
        res = get(full_dask, full_keys, cache=cache, num_workers=num_workers, optimize_graph=False)

    return res
Project: daskos | Author: daskos
def params(self):
        return merge(*self._params)
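
Star-unpacking a stored list of dicts is another common form: merge(*self._params) collapses the list into a single dict, later entries overriding earlier ones. For instance, assuming _params holds plain dicts:

from toolz import merge

_params = [{'cpus': 1, 'mem': 256}, {'mem': 512}]
merge(*_params)  # {'cpus': 1, 'mem': 512}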
Project: gpu_monitor | Author: msalvaris
def _influxdb_writer_for(influxdb_client, measurement):
    mes_dict = {"measurement": measurement}
    def to_influxdf(*data_dicts):
        merged_dicts = merge(mes_dict, *data_dicts)
        logger.debug(merged_dicts)
        if influxdb_client.write_points([merged_dicts]):
            logger.debug("Success")
        else:
            logger.info("FAIL")
    return to_influxdf
Project: gpu_monitor | Author: msalvaris
def _add_tags(tags, json_dict):
    json_dict['tags'] = merge(json_dict['tags'], tags)
    return json_dict
Project: call_map | Author: nccgroup
def get_config(self):
        if self._cache is None:
            read_result = self.read_user_config()
            return tz.merge(self.default_user_config,
                            read_result,
                            self.session_overrides)
        else:
            return self._cache
Project: call_map | Author: nccgroup
def configure_role_markers(cls, want_unicode_role_markers):
        if want_unicode_role_markers:
            cls.role_markers = tz.merge(cls.default_role_markers, {'definition': '?', 'parent': '?'})
        else:
            cls.role_markers = cls.default_role_markers
Project: sgnmt | Author: ucam-smt
def apply(self, source_sentence, source_sentence_mask):
        """Produces source annotations, either non-recurrently or with
        a bidirectional RNN architecture.
        """
        # Time as first dimension
        source_sentence = source_sentence.T
        source_sentence_mask = source_sentence_mask.T

        embeddings = self.lookup.apply(source_sentence)

        if self.n_layers >= 1:
            representation = self.bidir.apply(
                merge(self.fwd_fork.apply(embeddings, as_dict=True),
                      {'mask': source_sentence_mask}),
                merge(self.back_fork.apply(embeddings, as_dict=True),
                      {'mask': source_sentence_mask})
            )
            for _ in xrange(self.n_layers-1):
                if self.skip_connections:
                    inp = tensor.concatenate([representation, embeddings],
                                             axis=2)
                else:
                    inp = representation
                representation = self.bidir.apply(
                    merge(self.mid_fwd_fork.apply(inp, as_dict=True),
                          {'mask': source_sentence_mask}),
                    merge(self.mid_back_fork.apply(inp, as_dict=True),
                          {'mask': source_sentence_mask})
                )
        else:
            representation = embeddings
        return representation, source_sentence_mask
Project: sgnmt | Author: ucam-smt
def apply(self, source_sentence, source_sentence_mask):
        """Produces source annotations, either non-recurrently or with
        a bidirectional RNN architecture.
        """
        # Time as first dimension
        source_sentence = source_sentence.T
        source_sentence_mask = source_sentence_mask.T
        embeddings = self.lookup.apply(source_sentence)
        representation = self.bidirs[0].apply(
                merge(self.fwd_forks[0].apply(embeddings, as_dict=True),
                      {'mask': source_sentence_mask}),
                merge(self.back_forks[0].apply(embeddings, as_dict=True),
                      {'mask': source_sentence_mask}))
        for i in xrange(1, self.n_layers):
            if self.skip_connections:
                inp = tensor.concatenate([representation, embeddings],
                                         axis=2)
            else:
                inp = representation
            representation = self.bidirs[i].apply(
                merge(self.fwd_forks[i].apply(inp, as_dict=True),
                      {'mask': source_sentence_mask}),
                merge(self.back_forks[i].apply(inp, as_dict=True),
                      {'mask': source_sentence_mask})
            )
        return representation, source_sentence_mask
Project: sgnmt | Author: ucam-smt
def apply(self, base_annotations, base_mask):
        ann_representation = self.transition.apply(
            **merge(self.rnn_inputs, {
                'mask': base_mask,
                'attended': base_annotations,
                'attended_mask': base_mask}))[0]
        return ann_representation, base_mask
Project: DCNMT | Author: SwordYork
def apply(self, char_seq, sample_matrix, char_aux):
        # Time as first dimension
        embeddings = self.lookup.apply(char_seq)
        gru_out = self.dgru.apply(
            **merge(self.gru_fork.apply(embeddings, as_dict=True),
                    {'mask': char_aux}))
        wgru_out = tensor.exp(self.wl.apply(self.bidir_w.apply(embeddings, char_aux)))

        if self.dgru_depth > 1:
            gru_out = gru_out[-1]

        gru_out = tensor.addbroadcast(wgru_out, 2) * gru_out
        sampled_representation = tensor.tanh(tensor.batched_dot(sample_matrix, gru_out.dimshuffle([1, 0, 2])))
        return sampled_representation.dimshuffle([1, 0, 2]), wgru_out
Project: DCNMT | Author: SwordYork
def apply(self, char_seq, sample_matrix, char_aux):
        # Time as first dimension
        embeddings = self.lookup.apply(char_seq)
        gru_out = self.dgru.apply(
            **merge(self.gru_fork.apply(embeddings, as_dict=True),
                    {'mask': char_aux}))
        if self.dgru_depth > 1:
            gru_out = gru_out[-1]
        sampled_representation = tensor.batched_dot(sample_matrix, gru_out.dimshuffle([1, 0, 2]))
        return sampled_representation.dimshuffle([1, 0, 2])
Project: DCNMT | Author: SwordYork
def single_emit(self, target_single_char, batch_size, mask, states=None):
        # Time as first dimension
        # only one batch
        embeddings = self.lookup.apply(target_single_char)
        if states is None:
            states = self.dgru.initial_states(batch_size)
        states_dict = {'states': states[0]}
        for i in range(1, self.dgru_depth):
            states_dict['states' + RECURRENTSTACK_SEPARATOR + str(i)] = states[i]
        gru_out = self.dgru.apply(**merge(self.gru_fork.apply(embeddings, as_dict=True), states_dict,
                                          {'mask': mask, 'iterate': False}))
        return gru_out
Project: DCNMT | Author: SwordYork
def __init__(self, vocab_size, embedding_dim, igru_state_dim, igru_depth, trg_dgru_depth, emitter,
                 feedback_brick, merge=None, merge_prototype=None, post_merge=None, **kwargs):
        merged_dim = igru_state_dim
        if not merge:
            merge = Merge(input_names=kwargs['source_names'],
                          prototype=merge_prototype)
        if not post_merge:
            post_merge = Bias(dim=merged_dim)

        # kept for backward compatibility
        if igru_depth == 1:
            self.igru = IGRU(dim=igru_state_dim)
        else:
            self.igru = RecurrentStack([IGRU(dim=igru_state_dim, name='igru')] +
                                       [UpperIGRU(dim=igru_state_dim, activation=Tanh(), name='upper_igru' + str(i))
                                        for i in range(1, igru_depth)],
                                       skip_connections=True)
        self.embedding_dim = embedding_dim
        self.emitter = emitter
        self.feedback_brick = feedback_brick
        self.merge = merge
        self.post_merge = post_merge
        self.merged_dim = merged_dim
        self.igru_depth = igru_depth
        self.trg_dgru_depth = trg_dgru_depth
        self.lookup = LookupTable(name='embeddings')
        self.vocab_size = vocab_size
        self.igru_state_dim = igru_state_dim
        self.gru_to_softmax = Linear(input_dim=igru_state_dim, output_dim=vocab_size)
        self.gru_fork = Fork([name for name in self.igru.apply.sequences
                              if name != 'mask' and name != 'input_states'], prototype=Linear(), name='gru_fork')

        children = [self.emitter, self.feedback_brick, self.merge, self.post_merge,
                    self.igru, self.lookup, self.gru_to_softmax, self.gru_fork]
        kwargs.setdefault('children', []).extend(children)
        super(Interpolator, self).__init__(**kwargs)
Project: DCNMT | Author: SwordYork
def _push_allocation_config(self):
        self.lookup.length = self.vocab_size
        self.lookup.dim = self.embedding_dim
        self.emitter.readout_dim = self.get_dim('readouts')
        self.merge.input_names = self.source_names
        self.merge.input_dims = self.source_dims
        self.merge.output_dim = self.merged_dim
        self.post_merge.input_dim = self.merged_dim
        self.post_merge.output_dim = self.igru_state_dim
        self.gru_fork.input_dim = self.embedding_dim
        self.gru_fork.output_dims = [self.igru.get_dim(name)
                                     for name in self.gru_fork.output_names]
Project: DCNMT | Author: SwordYork
def readout(self, **kwargs):
        merged = self.merge.apply(**{name: kwargs[name]
                                     for name in self.merge.input_names})
        merged = self.post_merge.apply(merged)
        return merged
Project: DCNMT | Author: SwordYork
def readout_gru(self, target_prev_char_seq, target_prev_char_aux, input_states):
        embeddings = self.lookup.apply(target_prev_char_seq)
        gru_out = self.igru.apply(
            **merge(self.gru_fork.apply(embeddings, as_dict=True),
                    {'mask': target_prev_char_aux, 'input_states': input_states}))
        if self.igru_depth > 1:
            gru_out = gru_out[-1]
        readout_chars = self.gru_to_softmax.apply(gru_out)
        return readout_chars
Project: dl4mt-multi-src | Author: nyu-dl
def do(self, which_callback, *args):
        iterations_done = self.main_loop.status['iterations_done']
        if self.burnin <= iterations_done:
            # Save the model here
            iterations_done = self.main_loop.status['iterations_done']
            filename = os.path.join(
                self.saveto, 'params_iter{}.npz'.format(iterations_done))
            s = signal.signal(signal.SIGINT, signal.SIG_IGN)
            logger.info(" Incremental dump {}".format(filename))
            params_to_save = []
            for cg_name in self.main_loop.models.keys():
                params_to_save.append(
                    self.main_loop.models[cg_name].get_param_values())
            params_to_save = merge(params_to_save)
            secure_numpy_save(params_to_save, filename)
            if self.save_iter_state:
                filename_is = os.path.join(
                    self.saveto,
                    'iterations_state_iter{}.pkl'.format(iterations_done))
                logger.info(" Incremental dump {}".format(filename_is))
                secure_pickle_dump(self.main_loop.iteration_state, filename_is)
            if self.save_log:
                filename_log = os.path.join(
                    self.saveto,
                    'log_iter{}'.format(iterations_done))
                logger.info(" Incremental dump {}".format(filename_log))
                secure_pickle_dump(self.main_loop.log, filename_log)
            signal.signal(signal.SIGINT, s)
Project: dl4mt-multi-src | Author: nyu-dl
def dump_parameters(self, main_loop):
        params_to_save = []
        for model in main_loop.models.values():
            params_to_save.append(model.get_param_values())
        secure_numpy_save(merge(params_to_save),
                          self.path_to_parameters)
Project: dl4mt-multi-src | Author: nyu-dl
def get_params(self):
        return merge(self.encoder.get_params(),
                     self.decoder.get_params())
Project: amino | Author: tek
def __call(self, a, kw):
        sub_a, rest = self.__substitute__(self.__args, List.wrap(a))
        sub_kw = merge(self.__kwargs, kw)
        return self.__func(*sub_a, **sub_kw), rest
Project: provenance | Author: bmabey
def artifact_record(**kargs):
    artifact_props = t.merge({k: None for k in  pc.artifact_properties},
                             _artifact_record_st.example(),
                             {'inputs': {'varargs':[1,2,3], 'kargs': {}},
                              'fn_module': 'foo', 'fn_name': 'bar',
                              'value': 55, 'name': 'bar',
                              'version': 0,
                              'serializer': 'joblib',
                              'run_info': pc.run_info()},
                             kargs)
    return pc.ArtifactRecord(**artifact_props)
Project: provenance | Author: bmabey
def __repr__(self):
        return "lazy_dict({})".format(
            t.merge(t.valmap(lambda _: "...", self.thunks), self.realized))
Project: provenance | Author: bmabey
def lazy_proxy_dict(artifacts_or_ids, group_artifacts_of_same_name=False):
    """
    Takes a list of artifacts or artifact ids and returns a dictionary whose
    keys are the names of the artifacts. The values will be lazily loaded into
    proxies as requested.

    Parameters
    ----------
    artifacts_or_ids : collection of artifacts or artifact ids (strings)

    group_artifacts_of_same_name: bool (default: False)
        If set to True then artifacts of the same name will be grouped
        together in one list. When set to False an exception will be raised.
    """
    if isinstance(artifacts_or_ids, dict):
        artifacts = t.valmap(coerce_to_artifact, artifacts_or_ids)
        lambdas = {name: (lambda a: lambda: a.proxy())(a)
                   for name, a in artifacts.items()}
        return lazy_dict(lambdas)

    # else we have a collection
    artifacts = coerce_to_artifacts(artifacts_or_ids)
    by_name = t.groupby(lambda a: a.name, artifacts)
    singles = t.valfilter(lambda l: len(l) == 1, by_name)
    multi = t.valfilter(lambda l: len(l) > 1, by_name)

    lambdas = {name: (lambda a: lambda: a.proxy())(a[0]) for name, a in singles.items()}

    if group_artifacts_of_same_name and len(multi) > 0:
        lambdas = t.merge(lambdas,
                          {name:
                           (lambda artifacts: (lambda: [a.proxy() for a in artifacts]))(artifacts)
                           for name, artifacts in multi.items()})

    if not group_artifacts_of_same_name and len(multi) > 0:
        raise ValueError("""Only artifacts with distinct names can be used in a lazy_proxy_dict.
Offending names: {}
Use the option `group_artifacts_of_same_name=True` if you want a list of proxies to be returned under the respective key.
        """.format({n: len(a) for n, a in multi.items()}))

    return lazy_dict(lambdas)
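
The grouping logic above leans on three toolz helpers: groupby buckets artifacts by name, valfilter splits unique names from duplicated ones, and merge folds the grouped entries back in. A condensed sketch of that pipeline on plain strings:

import toolz as t

names = ['a', 'b', 'a', 'c']
by_name = t.groupby(lambda n: n, names)  # {'a': ['a', 'a'], 'b': ['b'], 'c': ['c']}
singles = t.valfilter(lambda l: len(l) == 1, by_name)
multi = t.valfilter(lambda l: len(l) > 1, by_name)
result = t.merge(singles, {k: 'grouped' for k in multi})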
Project: provenance | Author: bmabey
def register_custom_objects(mapping, merge=False):
    # the global declaration is needed: without it the assignment below makes
    # REGISTERED_CUSTOM_OBJECTS a local, and the t.merge branch would raise
    # UnboundLocalError
    global REGISTERED_CUSTOM_OBJECTS
    if merge:
        res = t.merge(REGISTERED_CUSTOM_OBJECTS, mapping)
    else:
        res = mapping

    REGISTERED_CUSTOM_OBJECTS = res



#TODO: move custom_objects into the attrs