Python data_helpers 模块，load_data() 实例源码

我们从 Python 开源项目中，提取了以下 2 个代码示例，用于说明如何使用 data_helpers.load_data()。

项目：mxnet_tk1    作者：starimpact    | 项目源码 | 文件源码
def train_without_pretrained_embedding(dev_size=1000, batch_size=50, num_embed=300,
                                       dropout=0.5, seed=10, ctx=None):
    """Train the sentence-classification CNN without a pretrained embedding.

    Loads the corpus via ``data_helpers.load_data()``, shuffles it with a fixed
    seed, holds out the last ``dev_size`` shuffled examples as a dev set, builds
    the model with ``setup_cnn_model(..., with_embedding=False)`` and trains it
    with ``train_cnn``.

    Every keyword argument defaults to the value that used to be hard-coded
    here, so calling with no arguments behaves as before.

    Args:
        dev_size: number of shuffled examples held out as the dev set.
        batch_size: mini-batch size passed to the model builder and trainer.
        num_embed: embedding dimension passed to ``setup_cnn_model``.
        dropout: dropout rate passed to ``setup_cnn_model``.
        seed: value given to ``np.random.seed`` before shuffling, so the
            train/dev split is reproducible.
        ctx: MXNet context for the model; ``None`` means ``mx.gpu(0)``.

    Raises:
        ValueError: if ``dev_size`` is not strictly between 0 and the number of
            loaded examples (either the train or the dev split would be empty).
    """
    x, y, vocab, vocab_inv = data_helpers.load_data()
    vocab_size = len(vocab)

    num_examples = len(y)
    # A negative-offset slice like x[:-dev_size] returns an empty train set when
    # the corpus is smaller than dev_size (and misbehaves for dev_size == 0),
    # so fail loudly instead of training on nothing.
    if not 0 < dev_size < num_examples:
        raise ValueError('dev_size must be in (0, %d), got %r' % (num_examples, dev_size))

    # Seeds the *global* NumPy RNG (as before) so the shuffle is reproducible.
    np.random.seed(seed)
    shuffle_indices = np.random.permutation(np.arange(num_examples))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # Split on an explicit index: the last dev_size shuffled rows form the dev set.
    split = num_examples - dev_size
    x_train, x_dev = x_shuffled[:split], x_shuffled[split:]
    y_train, y_dev = y_shuffled[:split], y_shuffled[split:]

    # Single-argument print(...) produces identical output on Python 2 and 3.
    print('Train/Dev split: %d/%d' % (len(y_train), len(y_dev)))
    print('train shape: %s' % (x_train.shape,))
    print('dev shape: %s' % (x_dev.shape,))
    print('vocab_size %s' % (vocab_size,))

    sentence_size = x_train.shape[1]

    print('batch size %s' % (batch_size,))
    print('sentence max words %s' % (sentence_size,))
    print('embedding size %s' % (num_embed,))

    if ctx is None:
        ctx = mx.gpu(0)
    cnn_model = setup_cnn_model(ctx, batch_size, sentence_size, num_embed, vocab_size,
                                dropout=dropout, with_embedding=False)
    train_cnn(cnn_model, x_train, y_train, x_dev, y_dev, batch_size)
项目：CNN-for-Sentence-Classification-in-Keras    作者：alexander-rakhlin    | 项目源码 | 文件源码
def load_data(data_source):
    """Load the sentence-classification data and split it into train/test sets.

    Args:
        data_source: ``"keras_data_set"`` to use the Keras IMDB data set
            (padded/truncated to ``sequence_length``, vocabulary capped at
            ``max_words``), or ``"local_dir"`` to use ``data_helpers.load_data()``
            with a shuffled 90/10 train/test split.

    Returns:
        ``(x_train, y_train, x_test, y_test, vocabulary_inv)``, where
        ``vocabulary_inv`` is a dict mapping integer word ids to words.

    Raises:
        ValueError: if ``data_source`` is not one of the supported values.
    """
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown source silently fall through
    # to the "local_dir" branch below.
    if data_source not in ("keras_data_set", "local_dir"):
        raise ValueError("Unknown data source")

    if data_source == "keras_data_set":
        # start_char/oov_char/index_from=None: no start marker, no OOV marker
        # and no id offset, so ids line up with imdb.get_word_index()
        # (per the Keras imdb docs -- confirm against the installed version).
        (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_words, start_char=None,
                                                              oov_char=None, index_from=None)

        # Pad/truncate at the end so every row has exactly sequence_length ids.
        x_train = sequence.pad_sequences(x_train, maxlen=sequence_length, padding="post", truncating="post")
        x_test = sequence.pad_sequences(x_test, maxlen=sequence_length, padding="post", truncating="post")

        vocabulary = imdb.get_word_index()
        vocabulary_inv = {index: word for word, index in vocabulary.items()}
        # pad_sequences pads with 0 by default, so label id 0 as padding.
        vocabulary_inv[0] = "<PAD/>"
    else:
        x, y, vocabulary, vocabulary_inv_list = data_helpers.load_data()
        vocabulary_inv = dict(enumerate(vocabulary_inv_list))
        # Collapse 2-D label rows (presumably one-hot) to class indices.
        y = y.argmax(axis=1)

        # Shuffle, then keep the first 90% for training and the rest for testing.
        shuffle_indices = np.random.permutation(np.arange(len(y)))
        x = x[shuffle_indices]
        y = y[shuffle_indices]
        train_len = int(len(x) * 0.9)
        x_train, x_test = x[:train_len], x[train_len:]
        y_train, y_test = y[:train_len], y[train_len:]

    return x_train, y_train, x_test, y_test, vocabulary_inv


# Data Preparation