我们从Python开源项目中,提取了以下2个代码示例,用于说明如何使用data_helpers.load_data()。
def train_without_pretrained_embedding():
    """Train a CNN text classifier from scratch, learning the embedding layer.

    Loads the dataset via ``data_helpers.load_data()``, shuffles it with a
    fixed seed, holds out the last 1000 examples as a dev set, builds the CNN
    on GPU 0 and runs training. No pretrained word embedding is used
    (``with_embedding=False``).
    """
    x, y, vocab, vocab_inv = data_helpers.load_data()
    vocab_size = len(vocab)

    # Randomly shuffle data; fixed seed keeps the train/dev split reproducible.
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # Split train/dev set: last 1000 shuffled examples become the dev set.
    x_train, x_dev = x_shuffled[:-1000], x_shuffled[-1000:]
    y_train, y_dev = y_shuffled[:-1000], y_shuffled[-1000:]
    # NOTE: original used Python 2 print statements; converted to print()
    # calls with identical output.
    print('Train/Dev split: %d/%d' % (len(y_train), len(y_dev)))
    print('train shape:', x_train.shape)
    print('dev shape:', x_dev.shape)
    print('vocab_size', vocab_size)

    batch_size = 50
    num_embed = 300
    sentence_size = x_train.shape[1]  # all rows padded to the same length
    print('batch size', batch_size)
    print('sentence max words', sentence_size)
    print('embedding size', num_embed)

    cnn_model = setup_cnn_model(mx.gpu(0), batch_size, sentence_size,
                                num_embed, vocab_size, dropout=0.5,
                                with_embedding=False)
    train_cnn(cnn_model, x_train, y_train, x_dev, y_dev, batch_size)
def load_data(data_source):
    """Load a train/test split plus an inverse vocabulary.

    ``data_source`` selects between the Keras IMDB dataset
    ("keras_data_set") and a local dataset loaded through ``data_helpers``
    ("local_dir"). Returns ``(x_train, y_train, x_test, y_test,
    vocabulary_inv)`` where ``vocabulary_inv`` maps token index -> word.
    """
    assert data_source in ["keras_data_set", "local_dir"], "Unknown data source"

    if data_source == "keras_data_set":
        # Pull the pre-tokenized IMDB reviews and pad/truncate at the tail.
        (x_train, y_train), (x_test, y_test) = imdb.load_data(
            num_words=max_words, start_char=None, oov_char=None,
            index_from=None)
        x_train = sequence.pad_sequences(
            x_train, maxlen=sequence_length, padding="post",
            truncating="post")
        x_test = sequence.pad_sequences(
            x_test, maxlen=sequence_length, padding="post", truncating="post")

        # Invert word -> index into index -> word; index 0 is padding.
        word_index = imdb.get_word_index()
        vocabulary_inv = {idx: word for word, idx in word_index.items()}
        vocabulary_inv[0] = "<PAD/>"
    else:
        x, y, _vocab, inv_list = data_helpers.load_data()
        vocabulary_inv = dict(enumerate(inv_list))
        y = y.argmax(axis=1)  # one-hot labels -> class ids

        # Shuffle data, then carve off the first 90% for training.
        perm = np.random.permutation(np.arange(len(y)))
        x, y = x[perm], y[perm]
        split_at = int(len(x) * 0.9)
        x_train, y_train = x[:split_at], y[:split_at]
        x_test, y_test = x[split_at:], y[split_at:]

    return x_train, y_train, x_test, y_test, vocabulary_inv


# Data Preparation