The following 50 code examples, extracted from open-source Python projects, show how toolz.merge() is used in practice.
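Before the extracted examples, here is a minimal, self-contained sketch of the call itself (the dictionary names below are illustrative only, not taken from any of the projects): toolz.merge() combines its argument dictionaries into a new dict, with later arguments winning on duplicate keys, and it also accepts a single iterable of dicts, as several examples below do with lists of parameter dictionaries.

from toolz import merge

defaults = {'host': 'localhost', 'port': 8080}
overrides = {'port': 9090}

# Later arguments take precedence on duplicate keys; the inputs are not mutated.
config = merge(defaults, overrides)
assert config == {'host': 'localhost', 'port': 9090}

# A single collection of dicts is also accepted.
assert merge([defaults, overrides]) == config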
def print_training_params(self, cgs, training_params):
    enc_dec_param_dict = merge(self.encoder.get_params(),
                               self.decoder.get_params())

    # Print which parameters are excluded
    for k, v in cgs.iteritems():
        excluded_all = list(set(v.parameters) - set(training_params[k]))
        for p in excluded_all:
            logger.info(
                'Excluding from training of CG[{}]: {}'
                .format(k, [key for key, val in enc_dec_param_dict.iteritems()
                            if val == p][0]))
        logger.info(
            'Total number of excluded parameters for CG[{}]: [{}]'
            .format(k, len(excluded_all)))

    for k, v in training_params.iteritems():
        for p in v:
            logger.info('Training parameter from CG[{}]: {}'
                        .format(k, p.name))
        logger.info(
            'Total number of parameters will be trained for CG[{}]: [{}]'
            .format(k, len(v)))
def scrutinize(self, scrutine, context_frame):
    constructor = type(scrutine)
    if constructor.__name__ != self._constructor_name:
        raise NoMatch()

    kwargs = scrutine._kwargs
    # the context to evaluate the thunk in
    context = {
        Call(Normal(name_lookup), (Normal(name),), {}): Normal(value)
        for name, value in merge(
            vars(builtins),
            context_frame.f_globals,
            context_frame.f_locals,
            # the newly bound arguments have the highest precedence
            dict(zip(self._argnames, scrutine._args)),
            {v: kwargs[k] for k, v in self._kwargnames.items()},
        ).items()
    }
    bound_tree = LTree.parse(self._expr).subs(context)
    return strict(bound_tree.lcompile())
def run_example(example_name, environ):
    """
    Run an example module from catalyst.examples.
    """
    mod = EXAMPLE_MODULES[example_name]

    register_calendar("YAHOO", get_calendar("NYSE"), force=True)

    return run_algorithm(
        initialize=getattr(mod, 'initialize', None),
        handle_data=getattr(mod, 'handle_data', None),
        before_trading_start=getattr(mod, 'before_trading_start', None),
        analyze=getattr(mod, 'analyze', None),
        bundle='test',
        environ=environ,
        # Provide a default capital base, but allow the test to override.
        **merge({'capital_base': 1e7}, mod._test_args())
    )
def prepare(query, connection=None, external=None):
    connection = merge(_default, connection or {})
    database = escape(connection['database'])
    query = query.format(db=database)
    params = {'query': query,
              'user': connection['user'],
              'password': connection['password']}
    params = valfilter(lambda x: x, params)

    files = {}
    external = external or {}
    for name, (structure, serialized) in external.items():
        params['{}_format'.format(name)] = 'CSV'
        params['{}_structure'.format(name)] = structure
        files[name] = serialized

    host = connection['host']

    return host, params, files
def apply(self, source_sentence, source_sentence_mask):
    """Creates bidirectional RNN source annotations.

    Args:
        source_sentence (Variable): Source sentence with words in
                                    vector representation.
        source_sentence_mask (Variable): Source mask

    Returns:
        Variable. source annotations
    """
    # Time as first dimension
    source_sentence = source_sentence.T
    source_sentence_mask = source_sentence_mask.T

    representation = self.bidir.apply(
        merge(self.fwd_fork.apply(source_sentence, as_dict=True),
              {'mask': source_sentence_mask}),
        merge(self.back_fork.apply(source_sentence, as_dict=True),
              {'mask': source_sentence_mask})
    )
    return representation, source_sentence_mask
def __init__(self, key, mapping=None, **kwargs):
    self._map = {}
    self._sorted_key_names = []
    self._sort_key = key
    self.update(merge(mapping or {}, kwargs))
def load_adjusted_array(self, columns, dates, assets, mask):
    return merge(
        self.get_loader(column).load_adjusted_array(
            [column], dates, assets, mask
        )
        for column in columns
    )
def test_ewm_stats(self, window_length):

    def ewma_name(decay_rate):
        return 'ewma_%s' % decay_rate

    def ewmstd_name(decay_rate):
        return 'ewmstd_%s' % decay_rate

    decay_rates = [0.25, 0.5, 0.75]
    ewmas = {
        ewma_name(decay_rate): EWMA(
            inputs=(USEquityPricing.close,),
            window_length=window_length,
            decay_rate=decay_rate,
        )
        for decay_rate in decay_rates
    }

    ewmstds = {
        ewmstd_name(decay_rate): EWMSTD(
            inputs=(USEquityPricing.close,),
            window_length=window_length,
            decay_rate=decay_rate,
        )
        for decay_rate in decay_rates
    }

    all_results = self.engine.run_pipeline(
        Pipeline(columns=merge(ewmas, ewmstds)),
        self.dates[window_length],
        self.dates[-1],
    )

    for decay_rate in decay_rates:
        ewma_result = all_results[ewma_name(decay_rate)].unstack()
        ewma_expected = self.expected_ewma(window_length, decay_rate)
        assert_frame_equal(ewma_result, ewma_expected)

        ewmstd_result = all_results[ewmstd_name(decay_rate)].unstack()
        ewmstd_expected = self.expected_ewmstd(window_length, decay_rate)
        assert_frame_equal(ewmstd_result, ewmstd_expected)
def do(self, which_callback, *args):
    iterations_done = self.main_loop.status['iterations_done']
    if self.burnin <= iterations_done:
        # Save the model here
        iterations_done = self.main_loop.status['iterations_done']
        filename = os.path.join(
            self.saveto, 'params_iter{}.npz'.format(iterations_done))
        s = signal.signal(signal.SIGINT, signal.SIG_IGN)
        logger.info(" Incremental dump {}".format(filename))

        params_to_save = []
        for cg_name in self.main_loop.models.keys():
            params_to_save.append(
                self.main_loop.models[cg_name].get_param_values())
        params_to_save = merge(params_to_save)
        secure_numpy_save(params_to_save, filename)

        if self.save_iter_state:
            filename_is = os.path.join(
                self.saveto,
                'iterations_state_iter{}.pkl'.format(iterations_done))
            logger.info(" Incremental dump {}".format(filename_is))
            secure_pickle_dump(self.main_loop.iteration_state, filename_is)

        if self.save_log:
            filename_log = os.path.join(
                self.saveto,
                'log_iter{}'.format(iterations_done))
            logger.info(" Incremental dump {}".format(filename_log))
            secure_pickle_dump(self.main_loop.log, filename_log)

        signal.signal(signal.SIGINT, s)
def dump_parameters(self, main_loop):
    params_to_save = []
    for model in main_loop.models.values():
        params_to_save.append(model.get_param_values())
    secure_numpy_save(merge(params_to_save),
                      self.path_to_parameters)
def get_params(self):
    return merge(self.encoder.get_params(),
                 self.decoder.get_params())
def _save_model(self, bleu_score):
    if self._is_valid_to_save(bleu_score):
        model = ModelInfo(
            bleu_score, self.saveto, self.enc_id, self.dec_id)

        # Manage n-best model list first
        if len(self.best_models) >= self.track_n_models:
            old_model = self.best_models[0]
            if old_model.path and os.path.isfile(old_model.path):
                logger.info("Deleting old model %s" % old_model.path)
                os.remove(old_model.path)
            self.best_models.remove(old_model)

        self.best_models.append(model)
        self.best_models.sort(key=operator.attrgetter('bleu_score'))

        # Save the model here
        s = signal.signal(signal.SIGINT, signal.SIG_IGN)
        logger.info("Saving new model {}".format(model.path))
        params_to_save = []
        for cg_name in self.main_loop.models.keys():
            params_to_save.append(
                self.main_loop.models[cg_name].get_param_values())
        params_to_save = merge(params_to_save)
        self._save_params(model, params_to_save)
        self._save_bleu_scores()
        signal.signal(signal.SIGINT, s)
def index(self):
    """Return dask Index instance"""
    name = self._name + '-index'
    dsk = {(name, i): (getattr, key, 'index')
           for i, key in enumerate(self._keys())}
    return Index(merge(dsk, self.dask), name,
                 self._meta.index, self.divisions)
def head(self, n=5, npartitions=1, compute=True):
    """ First n rows of the dataset

    Parameters
    ----------
    n : int, optional
        The number of rows to return. Default is 5.
    npartitions : int, optional
        Elements are only taken from the first ``npartitions``, with a
        default of 1. If there are fewer than ``n`` rows in the first
        ``npartitions`` a warning will be raised and any found rows
        returned. Pass -1 to use all partitions.
    compute : bool, optional
        Whether to compute the result, default is True.
    """
    if npartitions <= -1:
        npartitions = self.npartitions
    if npartitions > self.npartitions:
        raise ValueError("only %d partitions, received "
                         "%d" % (self.npartitions, npartitions))

    name = 'head-%d-%d-%s' % (npartitions, n, self._name)

    if npartitions > 1:
        name_p = 'head-partial-%d-%s' % (n, self._name)
        dsk = {(name_p, i): (M.head, (self._name, i), n)
               for i in range(npartitions)}
        dsk[(name, 0)] = (M.head, (gd.concat, sorted(dsk)), n)
    else:
        dsk = {(name, 0): (M.head, (self._name, 0), n)}

    res = new_dd_object(merge(self.dask, dsk), name, self._meta,
                        (self.divisions[0], self.divisions[npartitions]))

    return res.compute() if compute else res
def concat(objs):
    """Concatenate dask gdf objects

    Parameters
    ----------
    objs : sequence of DataFrame, Series, Index
        A sequence of objects to be concatenated.
    """
    objs = [_daskify(x) for x in objs]
    meta = gd.concat(_extract_meta(objs))

    name = "concat-" + uuid4().hex
    dsk = {}
    divisions = [0]
    base = 0
    lastdiv = 0
    for obj in objs:
        for k, i in obj._keys():
            dsk[name, base + i] = k, i
        base += obj.npartitions
        divisions.extend([d + lastdiv for d in obj.divisions[1:]])
        lastdiv = obj.divisions[-1]

    dasks = [o.dask for o in objs]
    dsk = merge(dsk, *dasks)

    return new_dd_object(dsk, name, meta, divisions)
def __getitem__(self, key):
    if isinstance(key, str) and key in self.columns:
        meta = self._meta[key]
        name = 'getitem-%s' % tokenize(self, key)
        dsk = {(name, i): (operator.getitem, (self._name, i), key)
               for i in range(self.npartitions)}
        return Series(merge(self.dask, dsk), name, meta, self.divisions)

    raise NotImplementedError("Indexing with %r" % key)
def __init__(self, scheduler, name, user=getpass.getuser(),
             master=os.getenv('MESOS_MASTER', 'zk://localhost:2181'),
             failover_timeout=100, capabilities=None,
             principal=None, secret=None,
             implicit_acknowledgements=True, handlers={}, loop=None):

    self.loop = loop or IOLoop()
    self.master = master
    self.leading_master_seq = None
    self.leading_master_info = None

    self.scheduler = scheduler
    self.framework = {
        'user': user,
        'name': name,
        'capabilities': capabilities or [],
        'failover_timeout': failover_timeout,
        'hostname': socket.gethostname()
    }

    self.implicit_acknowledgements = implicit_acknowledgements

    defaults = {Event.SUBSCRIBED: self.on_subscribed,
                Event.OFFERS: self.on_offers,
                Event.RESCIND: self.on_rescind,
                Event.UPDATE: self.on_update,
                Event.MESSAGE: self.on_message,
                Event.RESCIND_INVERSE_OFFER: self.on_rescind_inverse,
                Event.FAILURE: self.on_failure,
                Event.ERROR: self.on_error,
                Event.HEARTBEAT: self.on_heartbeat,
                Event.OUTBOUND_SUCCESS: self.on_outbound_success,
                Event.OUTBOUND_ERROR: self.on_outbound_error}
    self.handlers = merge(defaults, handlers)

    self.subscription = Subscription(self.framework, self.master,
                                     '/api/v1/scheduler', self.handlers,
                                     principal=principal,
                                     secret=secret,
                                     timeout=failover_timeout,
                                     loop=self.loop)
def factor_load(start_date, end_date, factor_name, save_file=None, **kwargs):
    """
    :param start_date: str, start date of the factor data to load
    :param end_date: str, end date of the factor data to load
    :param factor_name: str, name of the factor to load
    :param save_file: str, optional, path to save the result as '*.csv' or '*.pkl'
    :param kwargs: dict, optional
            freq: str, optional, data frequency, one of 'M', 'W', 'S', 'Y',
                  see enums.py - FreqType
            tenor: str, optional, tenor given as a number plus a FreqType, e.g. '3M'
            sec_id: str/list, optional, security id(s) to load
            output_data_format: enum, optional, see enums.py
                  MULTI_INDEX_DF: multi-index DataFrame, index=[date, secID], value = factor
                  PIVOT_TABLE_DF: DataFrame, index=date, columns = secID
            is_index: bool, optional,
                  True: treat sec_id as an index code and load the factor data of
                        its constituents
                  False: load the factor data of the given sec_id directly
            date_format: str, optional, date format, e.g. '%Y-%m-%d'
    :return: pd.DataFrame, the loaded factor data
    """
    if isinstance(factor_name, list):
        kwargs = merge(kwargs, {'output_data_format': OutputFormat.MULTI_INDEX_DF})
        factor_names = factor_name
    else:
        factor_names = [factor_name]

    ret = pd.DataFrame()
    for factor_name in factor_names:
        LOGGER.info('Loading factor data {0}'.format(factor_name))
        factor_loader = FactorLoader(start_date=start_date,
                                     end_date=end_date,
                                     factor_name=factor_name,
                                     **kwargs)
        factor_data = factor_loader.load_data()
        LOGGER.info('factor data {0} is loaded '.format(factor_name))
        ret = pd.concat([ret, factor_data], axis=1)

    if kwargs.get('reset_col_names'):
        ret.columns = factor_names

    if save_file:
        save_data_to_file(ret, save_file)
        LOGGER.critical('Data saved in {0}'.format(save_file))

    return ret
def __init__(self, volume_limit, eta=ROOT_SYMBOL_TO_ETA):
    super(VolatilityVolumeShare, self).__init__()
    self.volume_limit = volume_limit

    # If 'eta' is a constant, use a dummy mapping to treat it as a
    # dictionary that always returns the same value.
    # NOTE: This dictionary does not handle unknown root symbols, so it may
    # be worth revisiting this behavior.
    if isinstance(eta, (int, float)):
        self._eta = DummyMapping(float(eta))
    else:
        # Eta is a dictionary. If the user's dictionary does not provide a
        # value for a certain contract, fall back on the pre-defined eta
        # values per root symbol.
        self._eta = merge(ROOT_SYMBOL_TO_ETA, eta)
def __init__(self, cost, exchange_fee,
             min_trade_cost=DEFAULT_MINIMUM_COST_PER_FUTURE_TRADE):
    # If 'cost' or 'exchange fee' are constants, use a dummy mapping to
    # treat them as a dictionary that always returns the same value.
    # NOTE: These dictionaries do not handle unknown root symbols, so it
    # may be worth revisiting this behavior.
    if isinstance(cost, (int, float)):
        self._cost_per_contract = DummyMapping(float(cost))
    else:
        # Cost per contract is a dictionary. If the user's dictionary does
        # not provide a commission cost for a certain contract, fall back
        # on the pre-defined cost values per root symbol.
        self._cost_per_contract = defaultdict(
            lambda: DEFAULT_PER_CONTRACT_COST, **cost
        )

    if isinstance(exchange_fee, (int, float)):
        self._exchange_fee = DummyMapping(float(exchange_fee))
    else:
        # Exchange fee is a dictionary. If the user's dictionary does not
        # provide an exchange fee for a certain contract, fall back on the
        # pre-defined exchange fees per root symbol.
        self._exchange_fee = merge(
            FUTURE_EXCHANGE_FEES_BY_SYMBOL, exchange_fee,
        )

    self.min_trade_cost = min_trade_cost or 0
def load_adjusted_array(self, columns, dates, sids, mask):
    n, p = self.split_next_and_previous_event_columns(columns)
    return merge(
        self.load_next_events(n, dates, sids, mask),
        self.load_previous_events(p, dates, sids, mask),
    )
def json_repr(self):
    return merge(self.to_map(), dict(__type__=self.__path__))
def ribosome_file_logging(name: str, file_kw: dict=dict()) -> None:
    prefix_path = options.nvim_log_file.value | (lambda: amino.logging.log_dir() / 'nvim')
    level = (
        DDEBUG
        if options.development and options.spec else
        options.file_log_level.value | logging.DEBUG
    )
    logfile = Path(f'{prefix_path}_ribo_{name}_{os.getpid()}')
    kw = merge(
        file_kw,
        dict(level=level, logfile=logfile)
    )
    return amino_root_file_logging(**kw)
def main():
    global parser
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--output-html', help='Output result page HTML to file')
    parser.add_argument('--saisies', dest='saisie_variables', metavar='nom=valeur',
                        nargs='+', help='Variables saisies')
    parser.add_argument('--year', default='2015', type=int,
                        help='Calculer les impôts de l\'année N sur les revenus de l\'année N-1')
    args = parser.parse_args()

    cgi_url = 'http://www3.finances.gouv.fr/cgi-bin/calc-{}.cgi'.format(args.year)
    headers = {'User-Agent': 'Calculette-Impots-Python'}

    saisie_variables = ({} if args.saisie_variables is None
                        else dict(iter_saisie_variables(args.saisie_variables)))
    default_saisie_variables = {
        # '0DA': '1965',
        # '1AJ': '15000',
        'pre_situation_famille': 'C',
        'pre_situation_residence': 'M',
        # 'simplifie': '1',
    }
    data = merge(default_saisie_variables, saisie_variables)
    response = requests.post(cgi_url, headers=headers, data=data)

    if args.output_html is not None:
        with open(args.output_html, 'w') as output_html_file:
            output_html_file.write(re.sub(
                pattern=r'=(.)/calcul_impot/2015/',
                repl=r'=\1http://www3.finances.gouv.fr/calcul_impot/2015/',
                string=response.text,
            ))

    root_node = etree.fromstring(response.text, etree.HTMLParser())
    results = list(iter_results(root_node))
    print(json.dumps(results, ensure_ascii=False, indent=2, sort_keys=True))

    return 0
def run_all(values, base, get=get_proc, num_workers=4):
    full_dask = toolz.merge(val.dask for val in values)
    full_keys = [val._key for val in values]

    cache = {}
    if exists("{}.cache".format(base["prefix"])):
        with open("{}.cache".format(base["prefix"]), "r") as f:
            cache = json.load(f)
    full_dask.update(cache)

    with ProgressBar(), NekCallback(base) as rprof:
        res = get(full_dask, full_keys, cache=cache,
                  num_workers=num_workers, optimize_graph=False)

    return res
def params(self):
    return merge(*self._params)
def _influxdb_writer_for(influxdb_client, measurement):
    mes_dict = {"measurement": measurement}

    def to_influxdf(*data_dicts):
        merged_dicts = merge(mes_dict, *data_dicts)
        logger.debug(merged_dicts)
        if influxdb_client.write_points([merged_dicts]):
            logger.debug("Success")
        else:
            logger.info("FAIL")

    return to_influxdf
def _add_tags(tags, json_dict):
    json_dict['tags'] = merge(json_dict['tags'], tags)
    return json_dict
def get_config(self):
    if self._cache is None:
        read_result = self.read_user_config()
        return tz.merge(self.default_user_config,
                        read_result,
                        self.session_overrides)
    else:
        return self._cache
def configure_role_markers(cls, want_unicode_role_markers):
    if want_unicode_role_markers:
        cls.role_markers = tz.merge(cls.default_role_markers,
                                    {'definition': '?', 'parent': '?'})
    else:
        cls.role_markers = cls.default_role_markers
def apply(self, source_sentence, source_sentence_mask):
    """Produces source annotations, either non-recurrently or with
    a bidirectional RNN architecture.
    """
    # Time as first dimension
    source_sentence = source_sentence.T
    source_sentence_mask = source_sentence_mask.T
    embeddings = self.lookup.apply(source_sentence)

    if self.n_layers >= 1:
        representation = self.bidir.apply(
            merge(self.fwd_fork.apply(embeddings, as_dict=True),
                  {'mask': source_sentence_mask}),
            merge(self.back_fork.apply(embeddings, as_dict=True),
                  {'mask': source_sentence_mask})
        )
        for _ in xrange(self.n_layers-1):
            if self.skip_connections:
                inp = tensor.concatenate([representation, embeddings],
                                         axis=2)
            else:
                inp = representation
            representation = self.bidir.apply(
                merge(self.mid_fwd_fork.apply(inp, as_dict=True),
                      {'mask': source_sentence_mask}),
                merge(self.mid_back_fork.apply(inp, as_dict=True),
                      {'mask': source_sentence_mask})
            )
    else:
        representation = embeddings
    return representation, source_sentence_mask
def apply(self, source_sentence, source_sentence_mask):
    """Produces source annotations, either non-recurrently or with
    a bidirectional RNN architecture.
    """
    # Time as first dimension
    source_sentence = source_sentence.T
    source_sentence_mask = source_sentence_mask.T
    embeddings = self.lookup.apply(source_sentence)

    representation = self.bidirs[0].apply(
        merge(self.fwd_forks[0].apply(embeddings, as_dict=True),
              {'mask': source_sentence_mask}),
        merge(self.back_forks[0].apply(embeddings, as_dict=True),
              {'mask': source_sentence_mask}))
    for i in xrange(1, self.n_layers):
        if self.skip_connections:
            inp = tensor.concatenate([representation, embeddings],
                                     axis=2)
        else:
            inp = representation
        representation = self.bidirs[i].apply(
            merge(self.fwd_forks[i].apply(inp, as_dict=True),
                  {'mask': source_sentence_mask}),
            merge(self.back_forks[i].apply(inp, as_dict=True),
                  {'mask': source_sentence_mask})
        )
    return representation, source_sentence_mask
def apply(self, base_annotations, base_mask):
    ann_representation = self.transition.apply(
        **merge(self.rnn_inputs, {
            'mask': base_mask,
            'attended': base_annotations,
            'attended_mask': base_mask}))[0]
    return ann_representation, base_mask
def apply(self, char_seq, sample_matrix, char_aux):
    # Time as first dimension
    embeddings = self.lookup.apply(char_seq)
    gru_out = self.dgru.apply(
        **merge(self.gru_fork.apply(embeddings, as_dict=True),
                {'mask': char_aux}))
    wgru_out = tensor.exp(self.wl.apply(self.bidir_w.apply(embeddings, char_aux)))

    if self.dgru_depth > 1:
        gru_out = gru_out[-1]

    gru_out = tensor.addbroadcast(wgru_out, 2) * gru_out
    sampled_representation = tensor.tanh(
        tensor.batched_dot(sample_matrix, gru_out.dimshuffle([1, 0, 2])))
    return sampled_representation.dimshuffle([1, 0, 2]), wgru_out
def apply(self, char_seq, sample_matrix, char_aux):
    # Time as first dimension
    embeddings = self.lookup.apply(char_seq)
    gru_out = self.dgru.apply(
        **merge(self.gru_fork.apply(embeddings, as_dict=True),
                {'mask': char_aux}))
    if self.dgru_depth > 1:
        gru_out = gru_out[-1]
    sampled_representation = tensor.batched_dot(
        sample_matrix, gru_out.dimshuffle([1, 0, 2]))
    return sampled_representation.dimshuffle([1, 0, 2])
def single_emit(self, target_single_char, batch_size, mask, states=None):
    # Time as first dimension
    # only one batch
    embeddings = self.lookup.apply(target_single_char)
    if states is None:
        states = self.dgru.initial_states(batch_size)
    states_dict = {'states': states[0]}
    for i in range(1, self.dgru_depth):
        states_dict['states' + RECURRENTSTACK_SEPARATOR + str(i)] = states[i]
    gru_out = self.dgru.apply(
        **merge(self.gru_fork.apply(embeddings, as_dict=True),
                states_dict,
                {'mask': mask, 'iterate': False}))
    return gru_out
def __init__(self, vocab_size, embedding_dim, igru_state_dim, igru_depth,
             trg_dgru_depth, emitter, feedback_brick,
             merge=None, merge_prototype=None, post_merge=None, **kwargs):
    merged_dim = igru_state_dim
    if not merge:
        merge = Merge(input_names=kwargs['source_names'],
                      prototype=merge_prototype)
    if not post_merge:
        post_merge = Bias(dim=merged_dim)
    # for compatible
    if igru_depth == 1:
        self.igru = IGRU(dim=igru_state_dim)
    else:
        self.igru = RecurrentStack(
            [IGRU(dim=igru_state_dim, name='igru')] +
            [UpperIGRU(dim=igru_state_dim, activation=Tanh(),
                       name='upper_igru' + str(i))
             for i in range(1, igru_depth)],
            skip_connections=True)
    self.embedding_dim = embedding_dim
    self.emitter = emitter
    self.feedback_brick = feedback_brick
    self.merge = merge
    self.post_merge = post_merge
    self.merged_dim = merged_dim
    self.igru_depth = igru_depth
    self.trg_dgru_depth = trg_dgru_depth
    self.lookup = LookupTable(name='embeddings')
    self.vocab_size = vocab_size
    self.igru_state_dim = igru_state_dim
    self.gru_to_softmax = Linear(input_dim=igru_state_dim, output_dim=vocab_size)
    self.gru_fork = Fork([name for name in self.igru.apply.sequences
                          if name != 'mask' and name != 'input_states'],
                         prototype=Linear(), name='gru_fork')

    children = [self.emitter, self.feedback_brick, self.merge, self.post_merge,
                self.igru, self.lookup, self.gru_to_softmax, self.gru_fork]
    kwargs.setdefault('children', []).extend(children)
    super(Interpolator, self).__init__(**kwargs)
def _push_allocation_config(self):
    self.lookup.length = self.vocab_size
    self.lookup.dim = self.embedding_dim
    self.emitter.readout_dim = self.get_dim('readouts')
    self.merge.input_names = self.source_names
    self.merge.input_dims = self.source_dims
    self.merge.output_dim = self.merged_dim
    self.post_merge.input_dim = self.merged_dim
    self.post_merge.output_dim = self.igru_state_dim
    self.gru_fork.input_dim = self.embedding_dim
    self.gru_fork.output_dims = [self.igru.get_dim(name)
                                 for name in self.gru_fork.output_names]
def readout(self, **kwargs):
    merged = self.merge.apply(**{name: kwargs[name]
                                 for name in self.merge.input_names})
    merged = self.post_merge.apply(merged)
    return merged
def readout_gru(self, target_prev_char_seq, target_prev_char_aux, input_states):
    embeddings = self.lookup.apply(target_prev_char_seq)
    gru_out = self.igru.apply(
        **merge(self.gru_fork.apply(embeddings, as_dict=True),
                {'mask': target_prev_char_aux,
                 'input_states': input_states}))
    if self.igru_depth > 1:
        gru_out = gru_out[-1]
    readout_chars = self.gru_to_softmax.apply(gru_out)
    return readout_chars
def __call(self, a, kw):
    sub_a, rest = self.__substitute__(self.__args, List.wrap(a))
    sub_kw = merge(self.__kwargs, kw)
    return self.__func(*sub_a, **sub_kw), rest
def artifact_record(**kargs):
    artifact_props = t.merge({k: None for k in pc.artifact_properties},
                             _artifact_record_st.example(),
                             {'inputs': {'varargs': [1, 2, 3], 'kargs': {}},
                              'fn_module': 'foo', 'fn_name': 'bar',
                              'value': 55,
                              'name': 'bar', 'version': 0,
                              'serializer': 'joblib',
                              'run_info': pc.run_info()},
                             kargs)
    return pc.ArtifactRecord(**artifact_props)
def __repr__(self):
    return "lazy_dict({})".format(
        t.merge(t.valmap(lambda _: "...", self.thunks), self.realized))
def lazy_proxy_dict(artifacts_or_ids, group_artifacts_of_same_name=False):
    """
    Takes a list of artifacts or artifact ids and returns a dictionary whose
    keys are the names of the artifacts. The values will be lazily loaded into
    proxies as requested.

    Parameters
    ----------
    artifacts_or_ids : collection of artifacts or artifact ids (strings)

    group_artifacts_of_same_name: bool (default: False)
        If set to True then artifacts of the same name will be grouped together
        in one list. When set to False an exception will be raised
    """
    if isinstance(artifacts_or_ids, dict):
        artifacts = t.valmap(coerce_to_artifact, artifacts_or_ids)
        lambdas = {name: (lambda a: lambda: a.proxy())(a)
                   for name, a in artifacts.items()}
        return lazy_dict(lambdas)

    # else we have a collection
    artifacts = coerce_to_artifacts(artifacts_or_ids)
    by_name = t.groupby(lambda a: a.name, artifacts)
    singles = t.valfilter(lambda l: len(l) == 1, by_name)
    multi = t.valfilter(lambda l: len(l) > 1, by_name)

    lambdas = {name: (lambda a: lambda: a.proxy())(a[0])
               for name, a in singles.items()}

    if group_artifacts_of_same_name and len(multi) > 0:
        lambdas = t.merge(
            lambdas,
            {name: (lambda artifacts: (lambda: [a.proxy() for a in artifacts]))(artifacts)
             for name, artifacts in multi.items()})

    if not group_artifacts_of_same_name and len(multi) > 0:
        raise ValueError("""Only artifacts with distinct names can be used in a lazy_proxy_dict.
Offending names: {}
Use the option `group_artifacts_of_same_name=True`
if you want a list of proxies to be returned under the respective key.
""".format({n: len(a) for n, a in multi.items()}))

    return lazy_dict(lambdas)
def register_custom_objects(mapping, merge=False):
    # Without this declaration the assignment below would create a local
    # variable and raise UnboundLocalError when merge=True.
    global REGISTERED_CUSTOM_OBJECTS
    if merge:
        res = t.merge(REGISTERED_CUSTOM_OBJECTS, mapping)
    else:
        res = mapping
    REGISTERED_CUSTOM_OBJECTS = res
    # TODO: move custom_objects into the attrs