Python mxnet 模块，gpu() 实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 mxnet.gpu()。

项目:DQN-FlappyBird-mxnet    作者:foolyc    | 项目源码 | 文件源码
def __init__(self):
        """Create the replay memory, pick the device and build the Q-networks.

        Reads the module-level ``FLG_GPU``, ``args`` and shape constants
        (``BATCH``, ``FRAME``, ``HEIGHT``, ``WIDTH``, ``ACTIONS``).

        * When ``args.mode == 'train'``, ``self.q_net`` is built from
          ``self.createNet(1)`` and bound for training with batch size
          ``BATCH`` plus an Adam optimizer.
        * ``self.tg_net`` is always built from ``self.createNet()`` and bound
          for inference with batch size 1 (presumably the target network --
          confirm against ``createNet``).
        * If ``args.pretrain`` is set, that parameter file is loaded into
          every network built here.
        """
        self.replayMemory = deque()
        self.timestep = 0
        if FLG_GPU:
            self.ctx = mx.gpu()
        else:
            self.ctx = mx.cpu()
        if args.mode == 'train':
            self.q_net = mx.mod.Module(symbol=self.createNet(1), data_names=['frame', 'act_mul'],
                                       label_names=['target', ], context=self.ctx)
            self.q_net.bind(data_shapes=[('frame', (BATCH, FRAME, HEIGHT, WIDTH)), ('act_mul', (BATCH, ACTIONS))],
                            label_shapes=[('target', (BATCH,))], for_training=True)
            self.q_net.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
            self.q_net.init_optimizer(optimizer='adam',
                                      optimizer_params={'learning_rate': 0.0002, 'wd': 0.0, 'beta1': 0.5})
            if args.pretrain:
                self.q_net.load_params(args.pretrain)
                # Call form: the bare print statement is a SyntaxError on
                # Python 3, while this prints the same text on Python 2.
                print("load pretrained file......")

        self.tg_net = mx.mod.Module(symbol=self.createNet(), data_names=['frame',], label_names=[], context=self.ctx)
        self.tg_net.bind(data_shapes=[('frame', (1, FRAME, HEIGHT, WIDTH))], for_training=False)
        self.tg_net.init_params(initializer=mx.init.Xavier(factor_type='in', magnitude=2.34))
        if args.pretrain:
            self.tg_net.load_params(args.pretrain)
            print("load pretrained file......")
项目:mx-rfcn    作者:giorking    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Region Proposal Network.

    Returns:
        argparse.Namespace with ``image_set``, ``year``, ``root_path``,
        ``devkit_path``, ``prefix``, ``epoch``, ``gpu_id`` and ``vis``.
        ``prefix`` has no default, so it is None unless ``--prefix`` is given.
    """
    parser = argparse.ArgumentParser(description='Test a Region Proposal Network')
    parser.add_argument('--image_set', dest='image_set', help='can be trainval or train',
                        default='trainval', type=str)
    parser.add_argument('--year', dest='year', help='can be 2007, 2010, 2012',
                        default='2007', type=str)
    parser.add_argument('--root_path', dest='root_path', help='output data folder',
                        default=os.path.join(os.getcwd(), 'data'), type=str)
    parser.add_argument('--devkit_path', dest='devkit_path', help='VOCdevkit path',
                        default=os.path.join(os.getcwd(), 'data', 'VOCdevkit'), type=str)
    parser.add_argument('--prefix', dest='prefix', help='model to test with', type=str)
    parser.add_argument('--epoch', dest='epoch', help='model to test with',
                        default=8, type=int)
    # This script evaluates a trained RPN, so the device is used for testing.
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device to test with',
                        default=0, type=int)
    parser.add_argument('--vis', dest='vis', help='turn on visualization', action='store_true')
    args = parser.parse_args()
    return args
项目:mx-rfcn    作者:giorking    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Fast R-CNN network.

    Data paths default to ``<cwd>/data`` and ``<cwd>/data/VOCdevkit``.
    ``--prefix`` has no default, so it is None unless given.
    """
    data_dir = os.path.join(os.getcwd(), 'data')
    parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
    add = parser.add_argument

    # dataset
    add('--image_set', dest='image_set', help='can be test',
        default='test', type=str)
    add('--year', dest='year', help='can be 2007, 2010, 2012',
        default='2007', type=str)
    add('--root_path', dest='root_path', help='output data folder',
        default=data_dir, type=str)
    add('--devkit_path', dest='devkit_path', help='VOCdevkit path',
        default=os.path.join(data_dir, 'VOCdevkit'), type=str)

    # model
    add('--prefix', dest='prefix', help='model to test with', type=str)
    add('--epoch', dest='epoch', help='model to test with',
        default=8, type=int)

    # runtime
    add('--gpu', dest='gpu_id', help='GPU device to test with',
        default=0, type=int)
    add('--vis', dest='vis', help='turn on visualization', action='store_true')
    add('--has_rpn', dest='has_rpn', help='generate proposals on the fly',
        action='store_true')
    add('--proposal', dest='proposal', help='can be ss for selective search or rpn',
        default='rpn', type=str)

    return parser.parse_args()
项目:srep    作者:Answeror    | 项目源码 | 文件源码
def _crossval_predict(self, **kargs):
    """Run prediction for one cross-validation fold.

    Pops ``proba`` (default False), ``fold`` and ``Mod`` from the keyword
    arguments and forwards the rest unchanged. ``Mod`` is deep-copied before
    its ``params`` are replaced with the fold-specific ones, so the caller's
    object is not modified. ``context`` is taken from ``Mod`` and defaults to
    ``[mx.gpu(0)]``.
    """
    want_proba = kargs.pop('proba', False)
    fold_index = int(kargs.pop('fold'))
    mod_spec = deepcopy(kargs.pop('Mod'))
    mod_spec.update(params=self.format_params(mod_spec['params'], fold_index))
    ctx = mod_spec.pop('context', [mx.gpu(0)])

    # Go through call_and_shelve(...).get() so the result is loaded back from
    # disk. Otherwise downstream cached methods (e.g. vote) would keep two
    # caches: one for the fresh computation and one for the cached one.
    runner = _crossval_predict_proba_aux if want_proba else _crossval_predict_aux
    shelved = runner.call_and_shelve(self, Mod=mod_spec, fold=fold_index, context=ctx, **kargs)
    return shelved.get()
项目:zhihu_cup    作者:Godricly    | 项目源码 | 文件源码
def sym_gen_word(bucket_key):
    """Build the word-level linear-regression symbol for one bucket.

    ``bucket_key`` is a comma-separated string whose first two fields are the
    ``tw`` and ``cw`` sequence lengths. Returns ``(loss, data_name,
    label_name)``; the two names and ``tw_cell``/``cw_cell``/``fc_module``
    come from module scope.
    """
    fields = bucket_key.split(',')
    tw_len, cw_len = int(fields[0]), int(fields[1])
    tw_in = mx.sym.Variable('tw_array')
    cw_in = mx.sym.Variable('cw_array')
    target = mx.sym.Variable('label')

    # One slice per time step along axis 1, fed to each branch's RNN cell.
    tw_steps = list(mx.symbol.SliceChannel(data=tw_in, axis=1, num_outputs=tw_len,
                                           squeeze_axis=True, name='tw_slice'))
    cw_steps = list(mx.symbol.SliceChannel(data=cw_in, axis=1, num_outputs=cw_len,
                                           squeeze_axis=True, name='cw_slice'))
    tw_seq, _ = tw_cell.unroll(tw_len, inputs=tw_steps, merge_outputs=True, layout='TNC')
    cw_seq, _ = cw_cell.unroll(cw_len, inputs=cw_steps, merge_outputs=True, layout='TNC')

    # TNC -> NCT, then global max-pool over the trailing (time) axis.
    tw_seq = mx.sym.transpose(tw_seq, (1, 2, 0))
    cw_seq = mx.sym.transpose(cw_seq, (1, 2, 0))
    tw_seq = mx.sym.Pooling(tw_seq, kernel=(1,), global_pool=True, pool_type='max')
    cw_seq = mx.sym.Pooling(cw_seq, kernel=(1,), global_pool=True, pool_type='max')

    net = mx.sym.Concat(tw_seq, cw_seq, name='concat')
    net = mx.sym.Dropout(net, p=0.5)
    for layer, width in (('fc1', 1024), ('fc2', 1024), ('feature', 2000)):
        net = fc_module(net, layer, num_hidden=width)
    loss = mx.sym.LinearRegressionOutput(net, label=target, name='regression')
    return loss, data_name, label_name

#mod = mx.module.BucketingModule(sym_gen_word, default_bucket_key=ziter.max_bucket_key,context=mx.gpu(1),data_names=data_name, label_names=label_name)
项目:zhihu_cup    作者:Godricly    | 项目源码 | 文件源码
def sym_gen_word(bucket_key):
    """Build the word-level logistic-regression symbol for one bucket.

    ``bucket_key`` is a comma-separated string whose first two fields are the
    ``tw`` and ``cw`` sequence lengths. Returns ``(loss, data_name,
    label_name)``; the two names and ``tw_cell``/``cw_cell``/``fc_module``
    come from module scope.
    """
    fields = bucket_key.split(',')
    tw_len, cw_len = int(fields[0]), int(fields[1])
    tw_in = mx.sym.Variable('tw_array')
    cw_in = mx.sym.Variable('cw_array')
    target = mx.sym.Variable('label')

    # One slice per time step along axis 1, fed to each branch's RNN cell.
    tw_steps = list(mx.symbol.SliceChannel(data=tw_in, axis=1, num_outputs=tw_len,
                                           squeeze_axis=True, name='tw_slice'))
    cw_steps = list(mx.symbol.SliceChannel(data=cw_in, axis=1, num_outputs=cw_len,
                                           squeeze_axis=True, name='cw_slice'))
    tw_seq, _ = tw_cell.unroll(tw_len, inputs=tw_steps, merge_outputs=True, layout='TNC')
    cw_seq, _ = cw_cell.unroll(cw_len, inputs=cw_steps, merge_outputs=True, layout='TNC')

    # TNC -> NCT, then global max-pool over the trailing (time) axis.
    tw_seq = mx.sym.transpose(tw_seq, (1, 2, 0))
    cw_seq = mx.sym.transpose(cw_seq, (1, 2, 0))
    tw_seq = mx.sym.Pooling(tw_seq, kernel=(1,), global_pool=True, pool_type='max')
    cw_seq = mx.sym.Pooling(cw_seq, kernel=(1,), global_pool=True, pool_type='max')

    net = mx.sym.Concat(tw_seq, cw_seq, name='concat')
    for layer, width in (('fc1', 1024), ('fc2', 1024)):
        net = fc_module(net, layer, num_hidden=width)
    # Unlike the linear variant, dropout is applied just before the output layer.
    net = mx.sym.Dropout(net, p=0.5)
    net = fc_module(net, 'feature', num_hidden=2000)
    loss = mx.sym.LogisticRegressionOutput(net, label=target, name='regression')
    return loss, data_name, label_name

#mod = mx.module.BucketingModule(sym_gen_word, default_bucket_key=ziter.max_bucket_key,context=mx.gpu(1),data_names=data_name, label_names=label_name)
项目:odnl    作者:lilhope    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Region Proposal Network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Region Proposal Network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.rpn_prefix, type=str)
    add('--epoch', help='model to test with', default=default.rpn_epoch, type=int)

    # rpn
    add('--gpu', help='GPU device to test with', default=0, type=int)
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='rpn proposal threshold', default=0, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')

    return parser.parse_args()
项目:odnl    作者:lilhope    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Fast R-CNN network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.rcnn_prefix, type=str)
    add('--epoch', help='model to test with', default=default.rcnn_epoch, type=int)
    add('--gpu', help='GPU device to test with', default=0, type=int)

    # rcnn
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='valid detection threshold', default=1e-3, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')
    add('--has_rpn', help='generate proposals on the fly', action='store_true')
    add('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)

    return parser.parse_args()
项目:odnl    作者:lilhope    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing an end-to-end Faster R-CNN model.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. The remaining options are registered afterwards, so
    their defaults are read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm). ``--root_path``
    defaults to ``default.root_path`` joined onto the current directory.
    """
    parser = argparse.ArgumentParser(description='Test a Faster R-CNN network')
    # general
    parser.add_argument('--network', help='network name', default=default.network, type=str)
    parser.add_argument('--dataset', help='dataset name', default=default.dataset, type=str)
    known_args, _ = parser.parse_known_args()
    generate_config(known_args.network, known_args.dataset)
    # Parsing again here would be redundant: no options have been added
    # since the call above, so it would return the same (unused) result.
    data_root = os.path.join(os.getcwd(), default.root_path)

    parser.add_argument('--root_path', help='output data folder', default=data_root, type=str)
    parser.add_argument('--subset', help='subset of dataset,only for refer dataset', default=default.subset, type=str)
    parser.add_argument('--split', help='split of dataset,only for refer dataset', default=default.split, type=str)
    # testing
    parser.add_argument('--prefix', help='model to test with', default=default.e2e_prefix, type=str)
    parser.add_argument('--epoch', help='model to test with', default=default.e2e_epoch, type=int)
    parser.add_argument('--gpu', help='GPU device to test with', default=0, type=int)
    # rcnn
    parser.add_argument('--vis', help='turn on visualization', action='store_true')
    parser.add_argument('--thresh', help='valid detection threshold', default=1e-1, type=float)
    parser.add_argument('--shuffle', help='shuffle data on visualization', action='store_true')
    # NOTE: store_true combined with default=True means args.has_rpn is always
    # True; passing the flag cannot turn it off.
    parser.add_argument('--has_rpn', help='generate proposals on the fly', action='store_true', default=True)
    parser.add_argument('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)
    args = parser.parse_args()
    return args
项目:mxnet-fast-neural-style    作者:SineYuan    | 项目源码 | 文件源码
def build_parser():
    """Return the ArgumentParser for the style-transfer inference CLI.

    ``--checkpoint``, ``--in-path`` and ``--out-path`` are required string
    options. ``--resize`` takes two ints, and ``--gpu`` defaults to the
    module-level ``GPU`` value (-1 selects the CPU, per its help text).
    """
    parser = ArgumentParser()

    # (flag, dest, help, metavar) for the three mandatory path-like options.
    required_options = (
        ('--checkpoint', 'checkpoint',
         'checkpoint params file which generated in training', 'CHECKPOINT'),
        ('--in-path', 'in_path', 'dir or file to transform', 'IN_PATH'),
        ('--out-path', 'out_path',
         'destination dir of transformed file or files', 'OUT_PATH'),
    )
    for flag, dest, help_text, metavar in required_options:
        parser.add_argument(flag, type=str, dest=dest, help=help_text,
                            metavar=metavar, required=True)

    parser.add_argument('--resize', type=int, nargs=2, dest='resize',
                        help='resize the input image files, usage: --resize=300 400')
    parser.add_argument('--gpu', type=int, default=GPU,
                        help='which gpu card to use, -1 means using cpu (default %(default)s)')
    return parser
项目:sockeye    作者:awslabs    | 项目源码 | 文件源码
def _setup_context(args, exit_stack):
    """Return the MXNet context to run on.

    With ``args.use_cpu`` the CPU context is returned. Otherwise at least one
    GPU must be detected and exactly one device id given. When device locking
    is enabled, the id is passed to ``acquire_gpus`` and its context manager
    is entered on ``exit_stack`` (presumably held until the stack closes).
    With locking disabled, a negative id falls back to device 0.
    """
    if args.use_cpu:
        return mx.cpu()

    check_condition(get_num_gpus() >= 1,
                    "No GPUs found, consider running on the CPU with --use-cpu "
                    "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                    "binary isn't on the path).")
    check_condition(len(args.device_ids) == 1, "cannot run on multiple devices for now")
    requested = args.device_ids[0]

    if not args.disable_device_locking:
        locked = exit_stack.enter_context(acquire_gpus([requested], lock_dir=args.lock_dir))
        device = locked[0]
    else:
        # Without locking, a negative id just means "take the first device".
        device = requested if requested >= 0 else 0
    return mx.gpu(device)
项目:Parallel-RFCN    作者:TrafficObjectDetection    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing an end-to-end Faster R-CNN network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Faster R-CNN network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.e2e_prefix, type=str)
    add('--epoch', help='model to test with', default=default.e2e_epoch, type=int)
    add('--gpu', help='GPU device to test with', default=0, type=int)

    # rcnn
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='valid detection threshold', default=1e-3, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')
    # NOTE: store_true with default=True leaves has_rpn always True.
    add('--has_rpn', help='generate proposals on the fly', action='store_true', default=True)
    add('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)

    return parser.parse_args()
项目:mlAlgorithms    作者:gu-yan    | 项目源码 | 文件源码
def __init__(self, modelprefix, imagepath, inputshape, labelpath, epoch=0, format='NCHW'):
        """Load class labels and an MXNet checkpoint into an inference Module.

        Args:
            modelprefix: checkpoint prefix passed to ``mx.model.load_checkpoint``.
            imagepath: stored as ``self.imagepath``; not used in this method.
            inputshape: its first three entries become the non-batch dims of
                the bound ``data`` shape (batch size is fixed at 1).
            labelpath: text file read line by line into ``self.labels``
                (trailing whitespace stripped).
            epoch: checkpoint epoch to load.
            format: stored as ``self.format``; not used in this method.
        """
        self.modelprefix = modelprefix
        self.imagepath = imagepath
        self.labelpath = labelpath
        self.inputshape = inputshape
        self.epoch = epoch
        self.format = format

        with open(labelpath, 'r') as label_file:
            self.labels = [line.rstrip() for line in label_file]

        symbol, arg_params, aux_params = mx.model.load_checkpoint(self.modelprefix, self.epoch)
        self.mod = mx.mod.Module(symbol=symbol, context=mx.gpu(), label_names=None)
        dims = self.inputshape
        input_shape = (1, dims[0], dims[1], dims[2])
        self.mod.bind(for_training=False,
                      data_shapes=[('data', input_shape)],
                      label_shapes=self.mod._label_shapes)
        # allow_missing lets checkpoints that lack some parameters still load.
        self.mod.set_params(arg_params, aux_params, allow_missing=True)
项目:mxnet_tk1    作者:starimpact    | 项目源码 | 文件源码
def test_group_kvstore(kv_type):
    """Check grouped push/pull on a kvstore of type ``kv_type``.

    Each key receives one array per GPU worker. After every push, the pulled
    values are compared against a NumPy reference that accumulates
    ``lr * sum(pushed arrays)`` per key (assumes the 'test' optimizer applies
    exactly that update -- TODO confirm). Module-level ``lr``, ``keys``,
    ``shapes``, ``nworker``, ``nrepeat`` and ``data`` drive the test.
    """
    print(kv_type)
    store = mx.kv.create(kv_type)
    store.set_optimizer(mx.optimizer.create('test', lr))
    store.init(keys, [mx.nd.zeros(shape) for shape in shapes])
    expected = [np.zeros(shape) for shape in shapes]
    pulled = [[mx.nd.zeros(shape, mx.gpu(dev)) for dev in range(nworker)] for shape in shapes]

    for step in range(nrepeat):
        batch = data[step]
        pushes = [[mx.nd.array(batch[k][dev], mx.gpu(dev)) for dev in range(nworker)]
                  for k in range(len(keys))]
        store.push(keys, pushes)
        store.pull(keys, out=pulled)

        grads = [sum(per_worker) for per_worker in batch]
        expected = [ref + grad * lr for ref, grad in zip(expected, grads)]
        for ref, outs in zip(expected, pulled):
            abs_diff = sum(np.sum(np.abs(o.asnumpy() - ref)) for o in outs)
            rel_err = abs_diff / np.sum(np.abs(ref))
            assert rel_err < 1e-6, (rel_err, ref.shape)
项目:mxnet_tk1    作者:starimpact    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Fast R-CNN network.

    Every option has an explicit destination, help text, default and type.
    Paths default to locations under the current working directory.
    """
    cwd = os.getcwd()
    parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
    # (flag, dest, help, default, type), registered in this order.
    options = [
        ('--image_set', 'image_set', 'can be test', 'test', str),
        ('--year', 'year', 'can be 2007, 2010, 2012', '2007', str),
        ('--root_path', 'root_path', 'output data folder',
         os.path.join(cwd, 'data'), str),
        ('--devkit_path', 'devkit_path', 'VOCdevkit path',
         os.path.join(cwd, 'data', 'VOCdevkit'), str),
        ('--prefix', 'prefix', 'new model prefix',
         os.path.join(cwd, 'model', 'frcnn'), str),
        ('--epoch', 'epoch', 'epoch of pretrained model', 9, int),
        ('--gpu', 'gpu_id', 'GPU device to test with', 0, int),
    ]
    for flag, dest, help_text, default_value, converter in options:
        parser.add_argument(flag, dest=dest, help=help_text,
                            default=default_value, type=converter)
    return parser.parse_args()
项目:additions_mxnet    作者:eldercrow    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Region Proposal Network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Region Proposal Network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.rpn_prefix, type=str)
    add('--epoch', help='model to test with', default=default.rpn_epoch, type=int)

    # rpn
    add('--gpu', help='GPU device to test with', default=0, type=int)
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='rpn proposal threshold', default=0, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')

    return parser.parse_args()
项目:additions_mxnet    作者:eldercrow    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Fast R-CNN network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.rcnn_prefix, type=str)
    add('--epoch', help='model to test with', default=default.rcnn_epoch, type=int)
    add('--gpu', help='GPU device to test with', default=0, type=int)

    # rcnn
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='valid detection threshold', default=1e-3, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')
    add('--has_rpn', help='generate proposals on the fly', action='store_true')
    add('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)

    return parser.parse_args()
项目:additions_mxnet    作者:eldercrow    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing an end-to-end Faster R-CNN network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Faster R-CNN network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.e2e_prefix, type=str)
    add('--epoch', help='model to test with', default=default.e2e_epoch, type=int)
    add('--gpu', help='GPU device to test with', default=0, type=int)

    # rcnn
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='valid detection threshold', default=1e-3, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')
    # NOTE: store_true with default=True leaves has_rpn always True.
    add('--has_rpn', help='generate proposals on the fly', action='store_true', default=True)
    add('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)

    return parser.parse_args()
项目:ademxapp    作者:itijyou    | 项目源码 | 文件源码
def _get_module(args, margs, dargs, net=None):
    """Wrap ``net`` in an ``mx.mod.Module`` spread over the GPUs in ``args.gpus``.

    ``args.gpus`` is a comma-separated string of integer device ids. When
    ``net`` is None it is built from ``model_specs``. Only
    ``net_type == 'rna'`` with ``net_name == 'a1'`` is handled here; any other
    spec raises NotImplementedError.
    """
    # Example of how the symbols for the bundled networks are created.
    if net is None and model_specs['net_type'] == 'rna':
        from util.symbol.symbol import cfg as symcfg
        symcfg['lr_type'] = 'alex'
        symcfg['workspace'] = dargs.mx_workspace
        symcfg['bn_use_global_stats'] = True
        if model_specs['net_name'] == 'a1':
            from util.symbol.resnet_v2 import fcrna_model_a1
            net = fcrna_model_a1(margs.classes, margs.feat_stride, bootstrapping=True)
    if net is None:
        raise NotImplementedError('Unknown network: {}'.format(vars(margs)))

    device_ids = [int(token) for token in args.gpus.split(',')]
    return mx.mod.Module(net, context=[mx.gpu(dev) for dev in device_ids])
项目:mx-rcnn    作者:precedenceguo    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Region Proposal Network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Region Proposal Network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.rpn_prefix, type=str)
    add('--epoch', help='model to test with', default=default.rpn_epoch, type=int)

    # rpn
    add('--gpu', help='GPU device to test with', default=0, type=int)
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='rpn proposal threshold', default=0, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')

    return parser.parse_args()
项目:mx-rcnn    作者:precedenceguo    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Fast R-CNN network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.rcnn_prefix, type=str)
    add('--epoch', help='model to test with', default=default.rcnn_epoch, type=int)
    add('--gpu', help='GPU device to test with', default=0, type=int)

    # rcnn
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='valid detection threshold', default=1e-3, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')
    add('--has_rpn', help='generate proposals on the fly', action='store_true')
    add('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)

    return parser.parse_args()
项目:mx-rcnn    作者:precedenceguo    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing an end-to-end Faster R-CNN network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Faster R-CNN network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.e2e_prefix, type=str)
    add('--epoch', help='model to test with', default=default.e2e_epoch, type=int)
    add('--gpu', help='GPU device to test with', default=0, type=int)

    # rcnn
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='valid detection threshold', default=1e-3, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')
    # NOTE: store_true with default=True leaves has_rpn always True.
    add('--has_rpn', help='generate proposals on the fly', action='store_true', default=True)
    add('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)

    return parser.parse_args()
项目:sia-cog    作者:deepakkumar1984    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Region Proposal Network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Region Proposal Network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.rpn_prefix, type=str)
    add('--epoch', help='model to test with', default=default.rpn_epoch, type=int)

    # rpn
    add('--gpu', help='GPU device to test with', default=0, type=int)
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='rpn proposal threshold', default=0, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')

    return parser.parse_args()
项目:sia-cog    作者:deepakkumar1984    | 项目源码 | 文件源码
def parse_args():
    """Parse command-line options for testing a Fast R-CNN network.

    ``--network`` and ``--dataset`` are parsed first and handed to
    ``generate_config``. Every later option is registered afterwards, so its
    default is read from ``default`` only after that call (presumably
    ``generate_config`` updates ``default`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
    add = parser.add_argument

    # general
    add('--network', help='network name', default=default.network, type=str)
    add('--dataset', help='dataset name', default=default.dataset, type=str)
    known, _ = parser.parse_known_args()
    generate_config(known.network, known.dataset)
    add('--image_set', help='image_set name', default=default.test_image_set, type=str)
    add('--root_path', help='output data folder', default=default.root_path, type=str)
    add('--dataset_path', help='dataset path', default=default.dataset_path, type=str)

    # testing
    add('--prefix', help='model to test with', default=default.rcnn_prefix, type=str)
    add('--epoch', help='model to test with', default=default.rcnn_epoch, type=int)
    add('--gpu', help='GPU device to test with', default=0, type=int)

    # rcnn
    add('--vis', help='turn on visualization', action='store_true')
    add('--thresh', help='valid detection threshold', default=1e-3, type=float)
    add('--shuffle', help='shuffle data on visualization', action='store_true')
    add('--has_rpn', help='generate proposals on the fly', action='store_true')
    add('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)

    return parser.parse_args()
项目:mxnet-gan    作者:tqchen    | 项目源码 | 文件源码
def test_log_sum_exp():
    """Check ``log_sum_exp`` forward and backward against NumPy references.

    The forward output must match ``np_log_sum_exp``. With an all-ones head
    gradient, the input gradient must equal ``np_softmax`` along ``axis``.
    """
    xpu = mx.gpu()
    shape = (2, 2, 100)
    axis = 2
    keepdims = True
    X = mx.sym.Variable('X')
    Y = log_sum_exp(X, axis=axis, keepdims=keepdims)
    # Allocate the input and its gradient buffer on the device the executor
    # is bound to (same approach as test_constant).
    x = mx.nd.array(np.random.normal(size=shape), ctx=xpu)
    # NOTE(review): this overwrites the random values with a constant, so only
    # the uniform-input case is exercised -- kept as-is; confirm intent.
    x[:] = 1
    xgrad = mx.nd.empty(x.shape, ctx=xpu)
    exec1 = Y.bind(xpu, args=[x], args_grad={'X': xgrad})
    exec1.forward()
    y = exec1.outputs[0]
    np.testing.assert_allclose(
        y.asnumpy(),
        np_log_sum_exp(x.asnumpy(), axis=axis, keepdims=keepdims))
    # Reuse the output buffer as an all-ones head gradient:
    # d/dx logsumexp(x) == softmax(x).
    y[:] = 1
    exec1.backward([y])
    np.testing.assert_allclose(
        xgrad.asnumpy(),
        np_softmax(x.asnumpy(), axis=axis) * y.asnumpy())
项目:mxnet-gan    作者:tqchen    | 项目源码 | 文件源码
def test_constant():
    """``constant(x) + Y`` evaluates to ``x + y``.

    The gradient written for ``Y`` equals the head gradient passed to
    ``backward``.
    """
    dev = mx.gpu()
    dims = (2, 2, 100)
    const_val = mx.nd.ones(dims, ctx=dev)
    y_val = mx.nd.ones(dims, ctx=dev)
    y_grad = mx.nd.zeros(dims, ctx=dev)

    expr = constant(const_val) + mx.sym.Variable('Y')
    executor = expr.bind(dev, args={'Y': y_val}, args_grad={'Y': y_grad})

    executor.forward()
    np.testing.assert_allclose(executor.outputs[0].asnumpy(),
                               (const_val + y_val).asnumpy())

    executor.backward([y_val])
    np.testing.assert_allclose(y_grad.asnumpy(), y_val.asnumpy())
项目:ResNet    作者:tornadomeet    | 项目源码 | 文件源码
def main():
    """Classify one image with a saved checkpoint and print top-1/top-5 labels.

    Reads the module-level ``args`` (``synset``, ``img``, ``gpu``, ``prefix``,
    ``epoch``). ``args.synset`` is a text file with one label per line, indexed
    by class id.

    Raises:
        IOError: if ``args.img`` cannot be read by OpenCV.
    """
    # Close the label file deterministically.
    with open(args.synset) as synset_file:
        synset = [l.strip() for l in synset_file.readlines()]

    raw = cv2.imread(args.img)
    if raw is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise IOError('could not read image: %s' % args.img)
    img = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))  # resize to 224*224 to fit model
    img = np.transpose(img, (2, 0, 1))  # (h, w, c) -> (c, h, w)
    img = img[np.newaxis, :]  # extend to (n, c, h, w)

    ctx = mx.gpu(args.gpu)
    sym, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch)
    arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
    arg_params["data"] = mx.nd.array(img, ctx)
    arg_params["softmax_label"] = mx.nd.empty((1,), ctx)
    exe = sym.bind(ctx, arg_params, args_grad=None, grad_req="null", aux_states=aux_params)
    exe.forward(is_train=False)

    prob = np.squeeze(exe.outputs[0].asnumpy())
    pred = np.argsort(prob)[::-1]  # class indices, most probable first
    print("Top1 result is: ", synset[pred[0]])
    print("Top5 result is: ", [synset[p] for p in pred[:5]])
项目:mx-lsoftmax    作者:luoyetx    | 项目源码 | 文件源码
def train():
    """Fit ``get_symbol()`` on MNIST with SGD, checkpointing after every epoch.

    Reads the module-level ``args`` (``gpu``, ``batch_size``, ``lr``,
    ``num_epoch``, ``model_prefix``). A negative ``args.gpu`` selects the CPU.
    The learning rate is multiplied by 0.1 every ``10 * batches_per_epoch``
    updates, i.e. roughly every 10 epochs.
    """
    device = mx.cpu() if args.gpu < 0 else mx.gpu(args.gpu)

    # Iterator settings shared by the training and validation streams.
    shared = dict(input_shape=(1, 28, 28), mean_r=128, scale=1./128,
                  batch_size=args.batch_size)
    train_iter = mx.io.MNISTIter(image='data/train-images-idx3-ubyte',
                                 label='data/train-labels-idx1-ubyte',
                                 shuffle=True, **shared)
    val_iter = mx.io.MNISTIter(image='data/t10k-images-idx3-ubyte',
                               label='data/t10k-labels-idx1-ubyte',
                               **shared)

    module = mx.mod.Module(symbol=get_symbol(), context=device,
                           data_names=('data',), label_names=('softmax_label',))

    num_examples = 60000  # size of the MNIST training split
    batches_per_epoch = int(num_examples / args.batch_size)
    scheduler = mx.lr_scheduler.FactorScheduler(step=10*batches_per_epoch, factor=0.1)

    module.fit(train_data=train_iter,
               eval_data=val_iter,
               eval_metric=mx.metric.Accuracy(),
               initializer=mx.init.Xavier(),
               optimizer='sgd',
               optimizer_params={'learning_rate': args.lr, 'momentum': 0.9,
                                 'wd': 0.0005, 'lr_scheduler': scheduler},
               num_epoch=args.num_epoch,
               batch_end_callback=mx.callback.Speedometer(args.batch_size, 50),
               epoch_end_callback=mx.callback.do_checkpoint(args.model_prefix))
项目:mx-lsoftmax    作者:luoyetx    | 项目源码 | 文件源码
def profile():
    """Profile forward+backward passes over the MNIST validation set.

    One full pass runs un-profiled as a warm-up. The iterator is then reset
    and a second pass is recorded by MXNet's symbolic profiler into
    ``profile.json``. Reads module-level ``args`` (``gpu``, ``batch_size``).
    """
    device = mx.cpu() if args.gpu < 0 else mx.gpu(args.gpu)
    val_iter = mx.io.MNISTIter(image='data/t10k-images-idx3-ubyte',
                               label='data/t10k-labels-idx1-ubyte',
                               input_shape=(1, 28, 28),
                               mean_r=128,
                               scale=1./128,
                               batch_size=args.batch_size)
    module = mx.mod.Module(symbol=get_symbol(), context=device,
                           data_names=('data',), label_names=('softmax_label',))
    module.bind(data_shapes=val_iter.provide_data,
                label_shapes=val_iter.provide_label,
                for_training=True)
    module.init_params(initializer=mx.init.Xavier())

    # warm-up pass, not recorded
    for batch in val_iter:
        module.forward_backward(batch)

    # recorded pass
    mx.profiler.profiler_set_config(mode='symbolic', filename='profile.json')
    mx.profiler.profiler_set_state('run')
    val_iter.reset()
    for batch in val_iter:
        module.forward_backward(batch)
    mx.profiler.profiler_set_state('stop')
项目:kaggle-dstl-satellite-imagery-feature-detection    作者:u1234x1234    | 项目源码 | 文件源码
def load_model(version, epoch, patch_size, batch_size=8, ctx=mx.gpu(), num_channels=20):
    """Load the checkpoint ``'models/' + version`` at ``epoch`` for inference.

    :param version: checkpoint name under ``models/``.
    :param epoch: checkpoint epoch to load.
    :param patch_size: height and width of each square input patch.
    :param batch_size: number of patches per forward pass.
    :param ctx: device to bind on. The default ``mx.gpu()`` is created once,
        when this function is defined.
    :param num_channels: number of input channels of ``data``. The default of
        20 is the value this loader was originally written for.
    :return: an ``mx.module.Module`` bound for inference with the
        checkpoint's parameters set.
    """
    sym, arg, aux = mx.model.load_checkpoint('models/' + version, epoch)
    mod = mx.module.Module(sym, context=ctx)
    mod.bind(data_shapes=[('data', (batch_size, num_channels, patch_size, patch_size))],
             for_training=False)
    mod.set_params(arg, aux)
    return mod
项目:kaggle-dstl-satellite-imagery-feature-detection    作者:u1234x1234    | 项目源码 | 文件源码
def load_model(version, epoch, batch_size=8, ctx=mx.gpu()):
    """Load the checkpoint ``'models/' + version`` at ``epoch`` for inference.

    Input names and shapes come from a mapping ``d`` (input name -> shape).
    ``d`` is not defined in this function, so it must exist at module level
    before this is called.

    NOTE(review): ``batch_size`` is not used here; the batch dimension is
    whatever ``d`` specifies -- confirm this is intended.

    :return: an ``mx.module.Module`` bound for inference with the
        checkpoint's parameters set.
    """
    sym, arg, aux = mx.model.load_checkpoint('models/' + version, epoch)
    mod = mx.module.Module(sym, context=ctx, data_names=list(d))
    # dict.items() works on Python 2 and 3; dict.iteritems() is Python-2 only.
    mod.bind(data_shapes=list(d.items()), for_training=False)
    mod.set_params(arg, aux)
    return mod
项目:kaggle-dstl-satellite-imagery-feature-detection    作者:u1234x1234    | 项目源码 | 文件源码
def predict(d):
    '''Return the predicted raster for one image as a float32 array.

    ``d`` is unpacked as ``(image_id, <unused>)``. The image is cut into a
    ``sf x sf`` grid of crops. The three inputs from each crop are predicted,
    and the per-crop outputs are tiled back, row-major, into a single
    ``(n_out, l_size*sf, l_size*sf)`` array.

    Relies on module-level ``version``, ``epoch``, ``batch_size``, ``a_size``,
    ``m_size``, ``p_size``, ``sf``, ``l_size`` and ``n_out``.
    '''
    image_id, _ = d
    mod = load_model(version, epoch, batch_size, mx.gpu(0))

    image_data = get_data(image_id, a_size, m_size, p_size, sf)

    # One (a, m, p) triple per grid cell, in row-major order.
    triples = [crop_maps(image_data, row * (1. / sf), col * (1. / sf), 1. / sf)
               for row in range(sf) for col in range(sf)]

    # Move the last axis to position 1 (presumably channels-last ->
    # channels-first; confirm against crop_maps).
    A = np.array([a for a, _, _ in triples]).transpose((0, 3, 1, 2))
    M = np.array([m for _, m, _ in triples]).transpose((0, 3, 1, 2))
    P = np.array([p for _, _, p in triples]).transpose((0, 3, 1, 2))
    batches = mx.io.NDArrayIter(data={'a_data': A,
                                      'm_data': M,
                                      'p_data': P}, batch_size=batch_size)
    crop_preds = mod.predict(batches).asnumpy()

    # Paste prediction k back at grid cell (k // sf, k % sf).
    tiled = np.zeros((n_out, l_size*sf, l_size*sf))
    for k in range(sf * sf):
        row, col = divmod(k, sf)
        tiled[:, l_size*row: l_size*(row+1),
              l_size*col: l_size*(col+1)] = crop_preds[k]

    assert tiled.shape[0] == n_out
    return tiled.astype(np.float32)
项目:resnet.mxnet    作者:TuSimple    | 项目源码 | 文件源码
def main(config):
    """Load a checkpoint and print its score on the validation iterator.

    :param config: object providing ``model_load_prefix``,
        ``model_load_epoch``, ``kv_store``, ``data_dir`` and ``batch_size``.
        The checkpoint is read from ``'./model/' + config.model_load_prefix``.
    """
    symbol, arg_params, aux_params = mx.model.load_checkpoint('./model/' + config.model_load_prefix, config.model_load_epoch)

    model = mx.model.FeedForward(symbol, mx.gpu(0), arg_params=arg_params, aux_params=aux_params)
    kv = mx.kvstore.create(config.kv_store)
    _, val, _ = imagenet_iterator(data_dir=config.data_dir,
                                  batch_size=config.batch_size,
                                  kv=kv)
    # Call form: identical output on Python 2, and valid syntax on Python 3.
    print(model.score(val))
项目:srep    作者:Answeror    | 项目源码 | 文件源码
def coral(gpu, fold, batch_size):
    """Score a CORAL/AdaBN-initialized model on the held-out ``fold`` and log accuracy.

    :param gpu: iterable of GPU ids; one ``mx.gpu`` context is created per id.
    :param fold: index of the held-out fold (one of ``range(10)`` below).
    :param batch_size: batch size for every data iterator built here.
    """
    net_kwargs = dict(
        num_filter=16,
        num_pixel=2,
        num_feature_block=2,
        num_gesture_block=0,
        num_hidden=512,
        num_bottleneck=128,
        dropout=0.5,
        num_channel=1,
    )
    with Context():
        val = data_s21.get_inter_subject_val(fold=fold, batch_size=batch_size)

        mod = Module(
            num_gesture=8,
            coral=True,
            adabn=True,
            adabn_num_epoch=10,
            symbol_kargs=net_kwargs,
            context=[mx.gpu(i) for i in gpu],
        )

        params_path = '.cache/sigr-inter-adabn-%d-v403/model-0060.params' % fold
        other_folds = [data_s21.get_coral([i], batch_size) for i in range(10) if i != fold]
        this_fold = data_s21.get_coral([fold], batch_size)
        mod.init_coral(params_path, other_folds, this_fold)

        metric = mx.metric.create('acc')
        mod.score(val, metric)
        logger.info('Fold {} accuracy: {}', fold, metric.get()[1])
项目:focal-loss    作者:unsky    | 项目源码 | 文件源码
def main():
    """Train an R-CNN on the GPUs in ``config.gpus`` (comma-separated ids).

    The symbol source file is copied next to the outputs first, and the
    checkpoint prefix is ``<output_path>/rcnn``. Reads module-level ``args``,
    ``config`` and ``curr_path``.
    """
    print('Called with argument:', args)
    devices = [mx.gpu(int(gid)) for gid in config.gpus.split(',')]
    logger, output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
    symbol_src = os.path.join(curr_path, 'symbols', config.symbol + '.py')
    shutil.copy2(symbol_src, output_path)

    prefix = os.path.join(output_path, 'rcnn')
    logging.info('########## TRAIN rcnn WITH IMAGENET INIT AND RPN DETECTION')
    ds = config.dataset
    train_cfg = config.TRAIN
    net_cfg = config.network
    train_rcnn(config, ds.dataset, ds.image_set, ds.root_path, ds.dataset_path,
               args.frequent, config.default.kvstore, train_cfg.FLIP, train_cfg.SHUFFLE, train_cfg.RESUME,
               devices, net_cfg.pretrained, net_cfg.pretrained_epoch, prefix, train_cfg.begin_epoch,
               train_cfg.end_epoch, train_shared=False, lr=train_cfg.lr, lr_step=train_cfg.lr_step,
               proposal=ds.proposal, logger=logger)
项目:focal-loss    作者:unsky    | 项目源码 | 文件源码
def main():
    """Run ``test_rcnn`` on ``config.dataset.test_image_set``.

    Uses one GPU context per comma-separated id in ``config.gpus``. The model
    prefix is resolved under a sibling directory named after the training
    image set(s). Reads module-level ``args`` and ``config``.
    """
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
    # Call form: same output on Python 2, valid syntax on Python 3.
    print(args)

    logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.test_image_set)

    # Multiple training sets are joined with '+'; their output dir uses '_'.
    train_set_dir = config.dataset.image_set.replace('+', '_')
    model_prefix = os.path.join(final_output_path, '..', train_set_dir, config.TRAIN.model_prefix)
    test_rcnn(config, config.dataset.dataset, config.dataset.test_image_set, config.dataset.root_path, config.dataset.dataset_path,
              ctx, model_prefix, config.TEST.test_epoch,
              args.vis, args.ignore_cache, args.shuffle, config.TEST.HAS_RPN, config.dataset.proposal, args.thresh, logger=logger, output_path=final_output_path)
项目:focal-loss    作者:unsky    | 项目源码 | 文件源码
def main():
    """Call ``train_net`` with one MXNet GPU context per id in ``config.gpus``.

    Reads module-level ``args`` and ``config``.
    """
    print('Called with argument:', args)
    gpu_ids = [int(tok) for tok in config.gpus.split(',')]
    devices = [mx.gpu(gid) for gid in gpu_ids]
    net_cfg = config.network
    train_cfg = config.TRAIN
    train_net(args, devices, net_cfg.pretrained, net_cfg.pretrained_epoch, train_cfg.model_prefix,
              train_cfg.begin_epoch, train_cfg.end_epoch, train_cfg.lr, train_cfg.lr_step)
项目:odnl    作者:lilhope    | 项目源码 | 文件源码
def main():
    """Parse the CLI and run ``train_rcnn`` on the GPUs listed in ``--gpus``."""
    opts = parse_args()
    print('Called with argument:', opts)
    devices = [mx.gpu(int(gid)) for gid in opts.gpus.split(',')]
    train_rcnn(opts.network, opts.dataset, opts.image_set, opts.root_path, opts.dataset_path,
               opts.frequent, opts.kvstore, opts.work_load_list, opts.no_flip, opts.no_shuffle, opts.resume,
               devices, opts.pretrained, opts.pretrained_epoch, opts.prefix, opts.begin_epoch, opts.end_epoch,
               train_shared=opts.train_shared,
               lr=opts.lr,
               lr_step=opts.lr_step,
               proposal=opts.proposal)
项目:odnl    作者:lilhope    | 项目源码 | 文件源码
def main():
    """Parse the CLI and run ``test_rpn`` on the single GPU ``--gpu``."""
    opts = parse_args()
    print('Called with argument:', opts)
    device = mx.gpu(opts.gpu)
    test_rpn(opts.network, opts.dataset, opts.image_set, opts.root_path, opts.dataset_path,
             device, opts.prefix, opts.epoch,
             opts.vis, opts.shuffle, opts.thresh)
项目:odnl    作者:lilhope    | 项目源码 | 文件源码
def main():
    """Parse the CLI and run ``test_rcnn`` on the single GPU ``--gpu``."""
    opts = parse_args()
    device = mx.gpu(opts.gpu)
    print(opts)
    test_rcnn(opts.network, opts.dataset, opts.image_set,
              opts.root_path, opts.dataset_path,
              device, opts.prefix, opts.epoch,
              opts.vis, opts.shuffle, opts.has_rpn,
              opts.proposal, opts.thresh)
项目:odnl    作者:lilhope    | 项目源码 | 文件源码
def main():
    """Parse the CLI and run ``test_rcnn`` on ``--subset``/``--split`` using GPU ``--gpu``."""
    opts = parse_args()
    device = mx.gpu(opts.gpu)
    print(opts)
    test_rcnn(opts.network, opts.dataset, opts.root_path,
              opts.subset, opts.split,
              device, opts.prefix, opts.epoch,
              opts.vis, opts.shuffle, opts.has_rpn,
              opts.proposal, opts.thresh)
项目:odnl    作者:lilhope    | 项目源码 | 文件源码
def main():
    """Parse the CLI and run ``train_net`` on the GPUs listed in ``--gpus``."""
    opts = parse_args()
    print('Called with argument:', opts)
    gpu_ids = opts.gpus.split(',')
    devices = [mx.gpu(int(gid)) for gid in gpu_ids]
    train_net(opts, devices, opts.pretrained, opts.pretrained_epoch, opts.prefix,
              opts.begin_epoch, opts.end_epoch,
              lr=opts.lr, lr_step=opts.lr_step)
项目:sockeye    作者:awslabs    | 项目源码 | 文件源码
def determine_context(args: argparse.Namespace, exit_stack: ExitStack) -> List[mx.Context]:
    """
    Pick the device context(s) to run on: the CPU, or one or more GPUs.

    :param args: Arguments as returned by argparse.
    :param exit_stack: An ExitStack from contextlib. Unless device locking is
                       disabled, the ``utils.acquire_gpus`` context is entered
                       on it, so it is released when the stack unwinds.
    :return: A list with the context(s) to run on.
    """
    if args.use_cpu:
        logger.info("Training Device: CPU")
        return [mx.cpu()]

    check_condition(utils.get_num_gpus() >= 1,
                    "No GPUs found, consider running on the CPU with --use-cpu "
                    "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                    "binary isn't on the path).")

    if not args.disable_device_locking:
        device_ids = exit_stack.enter_context(utils.acquire_gpus(args.device_ids, lock_dir=args.lock_dir))
    else:
        device_ids = utils.expand_requested_device_ids(args.device_ids)

    if args.batch_type == C.BATCH_TYPE_SENTENCE:
        # Sentence-based batches are split evenly across devices.
        num_devices = len(device_ids)
        check_condition(args.batch_size % num_devices == 0, "When using multiple devices the batch size must be "
                                                            "divisible by the number of devices. Choose a batch "
                                                            "size that is a multiple of %d." % num_devices)

    logger.info("Training Device(s): GPU %s", device_ids)
    return [mx.gpu(gpu_id) for gpu_id in device_ids]
项目:mxnet-ssd    作者:zhreshold    | 项目源码 | 文件源码
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, num_class,
                 nms_thresh=0.5, force_nms=True, nms_topk=400):
    """
    Build a ``Detector`` from a saved model.

    Parameters:
    ----------
    net : str or None
        test network name. If not None, a symbol is built with ``get_symbol``
        using the options below. If None, ``None`` is passed to ``Detector``
        as the network.
    prefix : str
        load model prefix
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.Context
        running context, mx.cpu() or mx.gpu(i)
    num_class : int
        number of classes
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    nms_topk : int
        forwarded to ``get_symbol`` (presumably the maximum number of boxes
        kept by NMS -- confirm against ``get_symbol``)

    Returns:
    ----------
    Detector
        detector constructed with the given model prefix/epoch and context
    """
    if net is not None:
        net = get_symbol(net, data_shape, num_classes=num_class, nms_thresh=nms_thresh,
                         force_nms=force_nms, nms_topk=nms_topk)
    detector = Detector(net, prefix, epoch, data_shape, mean_pixels, ctx=ctx)
    return detector
项目:Deformable-ConvNets    作者:msracver    | 项目源码 | 文件源码
def main():
    """Call ``train_net`` with one MXNet GPU context per id in ``config.gpus``.

    Reads module-level ``args`` and ``config``.
    """
    # Same text as the Python-2 statement ``print 'Called with argument:', args``,
    # but valid on Python 3 as well.
    print('Called with argument: %s' % (args,))
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
    train_net(args, ctx, config.network.pretrained, config.network.pretrained_epoch, config.TRAIN.model_prefix,
              config.TRAIN.begin_epoch, config.TRAIN.end_epoch, config.TRAIN.lr, config.TRAIN.lr_step)
项目:Deformable-ConvNets    作者:msracver    | 项目源码 | 文件源码
def main():
    """Run ``test_rcnn`` on ``config.dataset.test_image_set``.

    Uses one GPU context per comma-separated id in ``config.gpus``. The model
    prefix is resolved under a sibling directory named after the training
    image set(s). Reads module-level ``args`` and ``config``.
    """
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
    # Call form: same output on Python 2, valid syntax on Python 3.
    print(args)

    logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.test_image_set)

    # Multiple training sets are joined with '+'; their output dir uses '_'.
    train_set_dir = config.dataset.image_set.replace('+', '_')
    model_prefix = os.path.join(final_output_path, '..', train_set_dir, config.TRAIN.model_prefix)
    test_rcnn(config, config.dataset.dataset, config.dataset.test_image_set, config.dataset.root_path, config.dataset.dataset_path,
              ctx, model_prefix, config.TEST.test_epoch,
              args.vis, args.ignore_cache, args.shuffle, config.TEST.HAS_RPN, config.dataset.proposal, args.thresh, logger=logger, output_path=final_output_path)
项目:Deformable-ConvNets    作者:msracver    | 项目源码 | 文件源码
def main():
    """Train with ``train_net`` on every GPU id listed in ``config.gpus``.

    Reads module-level ``args`` and ``config``.
    """
    print('Called with argument:', args)
    train_cfg = config.TRAIN
    devices = [mx.gpu(int(tok)) for tok in config.gpus.split(',')]
    train_net(args, devices,
              config.network.pretrained, config.network.pretrained_epoch,
              train_cfg.model_prefix,
              train_cfg.begin_epoch, train_cfg.end_epoch,
              train_cfg.lr, train_cfg.lr_step)
项目:Deformable-ConvNets    作者:msracver    | 项目源码 | 文件源码
def main():
    """Run ``alternate_train`` with one GPU context per id in ``config.gpus``.

    Reads module-level ``args`` and ``config``.
    """
    print('Called with argument:', args)
    gpu_ids = config.gpus.split(',')
    contexts = [mx.gpu(int(gid)) for gid in gpu_ids]
    net_cfg = config.network
    alternate_train(args, contexts, net_cfg.pretrained, net_cfg.pretrained_epoch)
项目:Deformable-ConvNets    作者:msracver    | 项目源码 | 文件源码
def main():
    """Run ``test_rcnn`` on ``config.dataset.test_image_set``.

    Uses one GPU context per comma-separated id in ``config.gpus``. The model
    prefix is resolved under a sibling directory named after the training
    image set(s). Reads module-level ``args`` and ``config``.
    """
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
    # Call form: same output on Python 2, valid syntax on Python 3.
    print(args)

    logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.test_image_set)

    # Multiple training sets are joined with '+'; their output dir uses '_'.
    train_set_dir = config.dataset.image_set.replace('+', '_')
    model_prefix = os.path.join(final_output_path, '..', train_set_dir, config.TRAIN.model_prefix)
    test_rcnn(config, config.dataset.dataset, config.dataset.test_image_set, config.dataset.root_path, config.dataset.dataset_path,
              ctx, model_prefix, config.TEST.test_epoch,
              args.vis, args.ignore_cache, args.shuffle, config.TEST.HAS_RPN, config.dataset.proposal, args.thresh, logger=logger, output_path=final_output_path)
项目:Deformable-ConvNets    作者:msracver    | 项目源码 | 文件源码
def main():
    """Call ``train_net`` on the GPUs listed (comma-separated) in ``config.gpus``.

    Reads module-level ``args`` and ``config``.
    """
    print('Called with argument:', args)
    ctx_list = list(map(lambda gid: mx.gpu(int(gid)), config.gpus.split(',')))
    pretrained, pretrained_epoch = config.network.pretrained, config.network.pretrained_epoch
    tc = config.TRAIN
    train_net(args, ctx_list, pretrained, pretrained_epoch, tc.model_prefix,
              tc.begin_epoch, tc.end_epoch, tc.lr, tc.lr_step)