We extracted the following 2 code examples from Python open-source projects to illustrate how to use keras.callbacks.CallbackList().
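For orientation, here is a minimal hand-rolled driver loop showing how a CallbackList is typically constructed and how its lifecycle hooks are dispatched. This is a sketch assuming a Keras 2.x-style API (set_model, set_params, and the on_train_/on_epoch_/on_batch_ hooks); older releases exposed some of these under private names such as _set_model, which is why Example 2 below calls callbacks._set_model(self). The tiny Dense model and the dummy loss value are placeholders for illustration only.

    import keras
    import keras.callbacks as cbks

    # Tiny placeholder model so that set_model() has something to attach to.
    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse')

    history = cbks.History()
    callbacks = cbks.CallbackList([history])
    callbacks.set_model(model)
    callbacks.set_params({'epochs': 2, 'verbose': 0, 'metrics': ['loss']})

    callbacks.on_train_begin()
    for epoch in range(2):
        callbacks.on_epoch_begin(epoch)
        # ... run training batches here, wrapping each one in
        # callbacks.on_batch_begin(batch) / callbacks.on_batch_end(batch, logs) ...
        callbacks.on_epoch_end(epoch, logs={'loss': 0.0})  # dummy loss for illustration
    callbacks.on_train_end()

    print(history.history)  # e.g. {'loss': [0.0, 0.0]}, collected by the History callback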
Example 1

    def init_callbacks(self, for_worker=False):
        """Prepares all keras callbacks to be used in training.
        Automatically attaches a History callback to the end of the callback list.
        If for_worker is True, leaves out callbacks that only make sense with
        validation enabled."""
        import keras.callbacks as cbks

        remove_for_worker = [cbks.EarlyStopping, cbks.ModelCheckpoint]
        if for_worker:
            for obj in remove_for_worker:
                self.callbacks_list = [c for c in self.callbacks_list
                                       if not isinstance(c, obj)]

        self.model.history = cbks.History()
        self.callbacks = cbks.CallbackList(self.callbacks_list + [self.model.history])

        # it's possible to callback a different model than self
        # (used by Sequential models)
        if hasattr(self.model, 'callback_model') and self.model.callback_model:
            self.callback_model = self.model.callback_model
        else:
            self.callback_model = self.model

        self.callbacks.set_model(self.callback_model)
        self.callback_model.stop_training = False
Example 2

    def fit(self, dataloader, nb_iter=None, nb_epoch=None, iter_per_epoch=None,
            callbacks=[], verbose=0):
        """Trains the underlying Keras model.

        Args:
            dataloader (StandardDataLoader): Manages the loading of data to model.
            nb_iter (int): The number of iterations to train the model.
            nb_epoch (int): The number of epochs to train the model.
            iter_per_epoch (int): Defines the number of iterations per epoch.
            callbacks (list): List of Keras callbacks to run during training.
        """
        nb_iter, iter_per_epoch = self._get_iterations(
            nb_iter, nb_epoch, iter_per_epoch)
        callbacks = CallbackList(callbacks)
        callbacks._set_model(self)
        callbacks.on_train_begin()

        try:
            epoch = 0
            self.stop_training = False
            for i in xrange(nb_iter):
                # Begin epoch
                if i % iter_per_epoch == 0:
                    callbacks.on_epoch_begin(epoch)

                # Execution
                callbacks.on_batch_begin(i)
                if verbose > 0:
                    import time
                    time.sleep(0.001)
                    j = i % iter_per_epoch
                    perc = int(100 * (j + 1) / iter_per_epoch)
                    prog = ''.join(['='] * (perc / 2))
                    string = "[{:50s}] {:3d}%\r".format(prog, perc)
                    sys.stdout.write(string)
                    sys.stdout.flush()
                losses = self.keras_model.train_on_batch(
                    *dataloader.get_training_batch())
                callbacks.on_batch_end(i)

                # End epoch
                if (i + 1) % iter_per_epoch == 0:
                    callbacks.on_epoch_end(epoch, logs={'losses': losses})
                    epoch += 1
                if self.stop_training:
                    break
        except KeyboardInterrupt:
            print "\n[BayesNet] Abort: KeyboardInterrupt"
            raise
        callbacks.on_train_end()