The following 3 code examples, extracted from open-source Python projects, illustrate how to use keras.utils.plot_model().
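Before the extracted examples, here is a minimal sketch of the call itself, using a throwaway toy model (the model and output file name are illustrative, not from the projects below). plot_model() needs the pydot package and the Graphviz binaries installed, which is likely why the projects below leave the call commented out.

import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import plot_model

# A trivial model, just so there is a graph to draw.
model = Sequential([
    Dense(32, activation='relu', input_shape=(16,)),
    Dense(1, activation='sigmoid'),
])

# Writes a PNG of the layer graph; show_shapes adds input/output shapes.
plot_model(model, to_file='model.png', show_shapes=True)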
Example 1

# Needs numpy and Keras; MODEL_PATH, SAMPLE_CSV_FILE, get_data() and
# encode_data() come from the surrounding project.
import numpy as np
from keras.models import load_model
from keras.utils import plot_model

def check_model(path=MODEL_PATH, file=SAMPLE_CSV_FILE, nsamples=2):
    '''see predictions generated for the training dataset'''
    # load model
    model = load_model(path)
    # load data
    data, dic = get_data(file)
    rows, questions, true_answers = encode_data(data, dic)
    # visualize model graph
    # plot_model(model, to_file='tableqa_model.png')
    # predict answers
    prediction = model.predict([rows[:nsamples], questions[:nsamples]])
    print(prediction)
    predicted_answers = [[np.argmax(character) for character in sample]
                         for sample in prediction]
    print(predicted_answers)
    print(true_answers[:nsamples])
    # one hot encode answers
    # true_answers = [to_categorical(answer, num_classes=len(dic)) for answer in answers[:nsamples]]
    # decode chars from char ids (ints)
    inv_dic = {v: k for k, v in dic.items()}
    for i in range(nsamples):
        print('\n')
        # print('Predicted answer: ' + ''.join([dic[char] for char in sample]))
        print('Table: ' + ''.join([inv_dic[char_id] for char_id in rows[i] if char_id != 0]))
        print('Question: ' + ''.join([inv_dic[char_id] for char_id in questions[i] if char_id != 0]))
        print('Answer(correct): ' + ''.join([inv_dic[char_id] for char_id in true_answers[i] if char_id != 0]))
        print('Answer(predicted): ' + ''.join([inv_dic[char_id] for char_id in predicted_answers[i] if char_id != 0]))
Example 2

# Needs only the standard library at this level; build_models(),
# generate_batch_data(), evaluate() and save_predictions() come from
# the surrounding project.
import datetime
import sys

def rcnn_mtl(processed_datasets, index_embedding, params):
    start = datetime.datetime.now()
    x_trains, y_trains, x_tests, y_tests = processed_datasets

    mtl_model, single_models = build_models(params, index_embedding)
    mtl_model.summary()  # summary() prints directly and returns None
    # plot_model(mtl_model, to_file='mtl_model.png', show_shapes=True)

    itera = 0
    batch_input = {}
    batch_output = {}
    batch_size = params['batch_size']
    iterations = params['iterations']
    sys.stdout.write('\ntotal iterations: {}'.format(iterations))
    while itera < iterations:
        generate_batch_data(batch_input, batch_output, batch_size, x_trains, y_trains)
        mtl_model.train_on_batch(batch_input, batch_output)
        itera += 1
        if itera > 200 and itera % 100 == 0:
            sys.stdout.write('\n\ncurrent iteration: {}'.format(itera))
            # evaluate(single_models, x_trains, y_trains, 'train')
            evaluate(single_models, x_tests, y_tests, 'test')
            sys.stdout.flush()
            if itera >= 500:
                save_predictions(single_models, x_tests, params['prediction_path'])
                # save_models(single_models, params['save_model_path'])

    end = datetime.datetime.now()
    sys.stdout.write('\nused time: {}\n'.format(end - start))
Example 3

# Needs numpy, json, and f1_score (the signature matches
# sklearn.metrics.f1_score); prepare_inputs_and_outputs(), create_model(),
# MODEL_OUTPUT_FOLDER, MODEL_NAME and MAX_ITERATIONS come from the
# surrounding project.
import json
import numpy as np
from sklearn.metrics import f1_score

def train_and_evaluate(train, test, intents_lookup, save=False):
    validation_data = None
    train_inputs, train_labels = prepare_inputs_and_outputs(train, intents_lookup)
    if test:
        test_inputs, test_labels = prepare_inputs_and_outputs(test, intents_lookup)
        validation_data = test_inputs, test_labels

    print('Number of sentences for each intent, train and test')
    print(list(intents_lookup))
    print(train_labels.sum(axis=0))
    if test:
        print(test_labels.sum(axis=0))

    model = create_model(len(intents_lookup))  # first iteration
    # model.summary()
    # this requires graphviz binaries also
    # plot_model(model, to_file=MODEL_OUTPUT_FOLDER + '/model.png', show_shapes=True)

    history = model.fit(train_inputs, train_labels,
                        validation_data=validation_data,
                        epochs=MAX_ITERATIONS, batch_size=50)
    # keep only f1_scores
    history = {'train': np.array(history.history['f1_score']),
               'test': np.array(history.history.get('val_f1_score', []))}

    # compute f1 score weighted by support
    y_pred_train = model.predict(train_inputs)
    f1_train = f1_score(train_labels.argmax(axis=1), y_pred_train.argmax(axis=1), average='weighted')
    if test:
        y_pred_test = model.predict(test_inputs)
        f1_test = f1_score(test_labels.argmax(axis=1), y_pred_test.argmax(axis=1), average='weighted')
    else:
        f1_test = None
    # generate confusion matrix
    # confusion = utils.my_confusion_matrix(test_labels.argmax(
    #     axis=1), y_pred_test.argmax(axis=1), len(intents_lookup))
    print(f1_test, f1_train)

    if save:
        model.save(MODEL_OUTPUT_FOLDER + '/model.h5')
        stats = {}
        stats['model_name'] = MODEL_NAME
        stats['model'] = model.get_config()
        with open(MODEL_OUTPUT_FOLDER + '/stats.json', 'w+') as stats_file:
            json.dump(stats, stats_file)

    return history, f1_test, f1_train