Python data 模块，BaseTransform() 实例源码

我们从 Python 开源项目中提取了以下 1 个代码示例，用于说明如何使用 data.BaseTransform()。

项目:realtime-action-detection    作者:gurkirt    | 项目源码 | 文件源码
def main():
    """Evaluate saved SSD checkpoints on the test split and log the results.

    For each iteration number listed (comma-separated) in ``args.eval_iter``
    the matching checkpoint is loaded, passed to ``test_net`` together with
    the test dataset, and the returned per-class AP strings and mean AP are
    printed and written to ``testing-<iteration>.log`` in the experiment's
    cache directory.

    Side effects: reads the module-level ``args`` namespace and mutates
    ``args.save_root``, ``args.data_root`` and ``args.listid`` in place.
    """
    means = (104, 117, 123)  # mean values handed to BaseTransform

    exp_name = 'CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}'.format(args.dataset, args.input_type,
                            args.batch_size, args.basenet[:-14], int(args.lr * 100000))

    args.save_root += args.dataset + '/'
    args.data_root += args.dataset + '/'
    args.listid = '01'  # would be useful in JHMDB-21
    print('Exp name', exp_name, args.listid)

    # Loop-invariant values, computed once.
    exp_dir = args.save_root + 'cache/' + exp_name
    num_classes = len(CLASSES) + 1  # +1 for the background class
    # Without CUDA, remap checkpoint tensors to the CPU so that a model saved
    # on a GPU can still be loaded; None keeps torch.load's default behaviour.
    map_location = None if args.cuda else 'cpu'

    for iteration in [int(itr) for itr in args.eval_iter.split(',')]:
        trained_model_path = exp_dir + '/ssd300_ucf24_' + repr(iteration) + '.pth'
        # Line-buffered log; the context manager guarantees the file is closed
        # even if loading or evaluation raises.
        with open(exp_dir + "/testing-{:d}.log".format(iteration), "w", 1) as log_file:
            log_file.write(exp_name + '\n')
            log_file.write(trained_model_path + '\n')
            net = build_ssd(300, num_classes)  # initialize SSD
            net.load_state_dict(torch.load(trained_model_path, map_location=map_location))
            net.eval()
            if args.cuda:
                net = net.cuda()
                cudnn.benchmark = True
            print('Finished loading model %d !' % iteration)
            # Load dataset
            dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, means), AnnotationTransform(),
                                     input_type=args.input_type, full_test=True)
            # evaluation
            if args.cuda:
                # Only synchronize when CUDA is actually in use; the call
                # raises on machines without a CUDA device.
                torch.cuda.synchronize()
            tt0 = time.perf_counter()
            log_file.write('Testing net \n')
            mAP, ap_all, ap_strs = test_net(net, args.save_root, exp_name, args.input_type, dataset, iteration, num_classes)
            for ap_str in ap_strs:
                print(ap_str)
                log_file.write(ap_str + '\n')
            ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
            print(ptr_str)
            log_file.write(ptr_str)

            if args.cuda:
                torch.cuda.synchronize()
            print('Complete set time {:0.2f}'.format(time.perf_counter() - tt0))