我们从 Python 开源项目中提取了以下 3 个代码示例，用于说明如何使用 chainer.datasets 模块。
def get_mnist():
    """Load MNIST and repackage each split as a column-wise ``TupleDataset``.

    ``chainer.datasets.get_mnist(ndim=3)`` yields one 2-tuple ``(x, y)`` per
    example.  Each split is rebuilt as ``TupleDataset(xs, ys)``, where ``xs``
    and ``ys`` are plain Python lists holding the original per-example
    objects, in the original order.

    Returns:
        tuple: ``(train, test)`` datasets.
    """
    train, test = chainer.datasets.get_mnist(ndim=3)

    def _to_columns(dataset):
        # Split the (x, y) pairs into two parallel lists.  The previous
        # version went through np.array() on a list of (ndarray, scalar)
        # tuples; that is a ragged nested sequence, which NumPy >= 1.24
        # rejects with ValueError.  It also allocated two large object
        # arrays only to turn them back into lists.
        xs = []
        ys = []
        for x, y in dataset:
            xs.append(x)
            ys.append(y)
        return TupleDataset(xs, ys)

    return _to_columns(train), _to_columns(test)
def score_core(self, X, y=None, sample_weight=None, batchsize=16):
    """Run the model over ``X`` (and ``y``) and return ``self.total_score``.

    Args:
        X: Inputs; normalised together with ``y`` by ``self._check_X_y``.
        y: Targets, or ``None``.  The ``None`` case is supported because,
            per the original notes, GridSearch only assumes the
            ``score(X, y)`` interface.
        sample_weight: Accepted but not used by this method (presumably
            kept for scikit-learn ``score`` signature compatibility).
        batchsize: Batch size forwarded to ``self.forward_batch``.

    Returns:
        ``self.total_score`` as it stands after ``self.forward_batch`` runs
        with ``calc_score=True``.  Per the original notes, this is accuracy
        for a classifier and loss for a regressor.
    """
    X, y = self._check_X_y(X, y)

    if y is not None:
        dataset = chainer.datasets.TupleDataset(X, y)
    else:
        # No targets: use X directly.  A non-ndarray X is passed through
        # unchanged (assumed to already be a dataset; TODO confirm).
        dataset = X
        if isinstance(dataset, numpy.ndarray):
            # TODO: review
            print('score_core numpy.ndarray received...')
            dataset = chainer.datasets.TupleDataset(dataset)

    self.forward_batch(dataset, batchsize=batchsize, retain_inputs=False,
                       calc_score=True)
    return self.total_score
def main(args):
    """Pretrain (or load) a source CNN, then train a target CNN from it.

    Steps:
      1. Load SVHN as the source domain and MNIST (3-D, RGB format) as the
         target domain; MNIST images are resized to 32x32.
      2. Load source weights from ``args.output/args.pretrained_source`` if
         that file exists, otherwise call ``pretrain_source_cnn``.
      3. Evaluate the source model on the target domain.
      4. Create a fresh target model, copy the source parameters into it,
         and hand everything to ``train_target_cnn``.

    Args:
        args: Parsed options; ``args.output`` and ``args.pretrained_source``
            are read here, and ``args`` is forwarded to the helpers.
    """
    # Source domain: SVHN train/test splits.
    svhn_train, svhn_test = chainer.datasets.get_svhn()
    source = svhn_train, svhn_test

    # Target domain: MNIST as 3-D RGB-format images.
    mnist_train, mnist_test = chainer.datasets.get_mnist(ndim=3, rgb_format=True)

    def resize_to_32(example):
        # Resize the image to 32x32 and keep the label untouched.
        image, label = example
        return resize(image, (32, 32)), label

    target = (TransformDataset(mnist_train, resize_to_32),
              TransformDataset(mnist_test, resize_to_32))

    # Reuse saved source weights when available; otherwise pretrain them.
    checkpoint = os.path.join(args.output, args.pretrained_source)
    if os.path.isfile(checkpoint):
        source_cnn = Loss(num_classes=10)
        serializers.load_npz(checkpoint, source_cnn)
    else:
        source_cnn = pretrain_source_cnn(source, args)

    # Report how the source-trained model does on the target domain.
    test_pretrained_on_target(source_cnn, target, args)

    # Start the target model as a separate instance and copy the source
    # parameters into it (the original notes say not to use
    # source_cnn.copy() here).
    target_cnn = Loss(num_classes=10)
    target_cnn.copyparams(source_cnn)

    train_target_cnn(source, target, source_cnn, target_cnn, args)