我们从Python开源项目中,提取了以下6个代码示例,用于说明如何使用model.eval()。
def evaluate(data_source):
    """Return the length-weighted mean loss of the global ``model`` on ``data_source``.

    The data is walked in steps of ``args.bptt`` rows. Each chunk's
    ``criterion`` value is weighted by ``len(data)``, and the accumulated
    sum is divided by ``len(data_source)``.

    Relies on module-level names: ``model``, ``corpus``, ``criterion``,
    ``args``, ``eval_batch_size``, ``get_batch`` and ``repackage_hidden``.
    """
    # Evaluation mode disables dropout.
    model.eval()
    vocab_size = len(corpus.dictionary)
    state = model.init_hidden(eval_batch_size)
    loss_sum = 0
    for offset in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = get_batch(data_source, offset, evaluation=True)
        output, state = model(data, state)
        flat_output = output.view(-1, vocab_size)
        # Weight the chunk's loss by the number of rows it covers.
        loss_sum += len(data) * criterion(flat_output, targets).data
        # Hand the hidden state to repackage_hidden before the next chunk
        # (presumably detaches it from the graph -- helper not visible here).
        state = repackage_hidden(state)
    # loss_sum is indexed with [0] to pull out the scalar
    # (assumes a one-element tensor -- legacy PyTorch API).
    return loss_sum[0] / len(data_source)
def evaluate(data_source, batch_size=10):
    """Return the length-weighted mean loss of the global ``model`` on ``data_source``.

    When ``args.model`` is ``'QRNN'`` the model is reset first, then put in
    evaluation mode. The data is walked in steps of ``args.bptt`` rows; each
    chunk's ``criterion`` value is weighted by ``len(data)`` and the sum is
    divided by ``len(data_source)``.

    Args:
        data_source: Object exposing ``size(0)`` and ``len()``; passed to
            ``get_batch`` to obtain each chunk.
        batch_size: Value forwarded to ``model.init_hidden``.
    """
    if args.model == 'QRNN':
        model.reset()
    # Evaluation mode disables dropout.
    model.eval()
    vocab_size = len(corpus.dictionary)
    state = model.init_hidden(batch_size)
    loss_sum = 0
    for offset in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = get_batch(data_source, offset, args, evaluation=True)
        output, state = model(data, state)
        flat_output = output.view(-1, vocab_size)
        # Weight the chunk's loss by the number of rows it covers.
        loss_sum += len(data) * criterion(flat_output, targets).data
        state = repackage_hidden(state)
    # Indexing with [0] extracts the scalar
    # (assumes a one-element tensor -- legacy PyTorch API).
    return loss_sum[0] / len(data_source)
def evaluate(data_source, batch_size=10):
    """Return the length-weighted mean loss of the global ``model`` on ``data_source``.

    The model is switched to evaluation mode and then, when ``args.model``
    is ``'QRNN'``, reset. The data is walked in steps of ``args.bptt`` rows;
    each chunk's ``criterion`` value is weighted by ``len(data)`` and the
    sum is divided by ``len(data_source)``.

    Args:
        data_source: Object exposing ``size(0)`` and ``len()``; passed to
            ``get_batch`` to obtain each chunk.
        batch_size: Value forwarded to ``model.init_hidden``.
    """
    # Evaluation mode disables dropout.
    model.eval()
    if args.model == 'QRNN':
        model.reset()
    vocab_size = len(corpus.dictionary)
    state = model.init_hidden(batch_size)
    loss_sum = 0
    for offset in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = get_batch(data_source, offset, args, evaluation=True)
        output, state = model(data, state)
        flat_output = output.view(-1, vocab_size)
        # Weight the chunk's loss by the number of rows it covers.
        loss_sum += len(data) * criterion(flat_output, targets).data
        state = repackage_hidden(state)
    # Indexing with [0] extracts the scalar
    # (assumes a one-element tensor -- legacy PyTorch API).
    return loss_sum[0] / len(data_source)
def evaluate(data_source, batch_size=10, window=args.window):
    """Evaluate the global ``model`` with a pointer mixture over recent history.

    For every position the vocabulary softmax is used as the predictive
    distribution. Once more than ``window`` positions of history exist, it
    is mixed (weight ``args.lambdasm``) with a distribution built by
    attending, with temperature ``args.theta``, over the previous ``window``
    stored model outputs and their one-hot targets.

    Args:
        data_source: Object exposing ``size(0)`` and ``len()``; passed to
            ``get_batch`` to obtain each chunk.
        batch_size: Forwarded to ``model.init_hidden`` and used to scale each
            chunk's summed loss.
        window: Number of most recent positions kept as pointer history.
            The default is bound to ``args.window`` at definition time.

    Returns:
        The accumulated loss divided by ``len(data_source)``.
    """
    if args.model == 'QRNN':
        model.reset()
    # Evaluation mode disables dropout.
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    next_word_history = None
    pointer_history = None
    for i in range(0, data_source.size(0) - 1, args.bptt):
        if i > 0:
            # Progress report: offset, total length, exp of the running loss per offset.
            print(i, len(data_source), math.exp(total_loss / i))
        data, targets = get_batch(data_source, i, evaluation=True, args=args)
        output, hidden, rnn_outs, _ = model(data, hidden, return_h=True)
        # Last entry of rnn_outs (presumably the top layer -- TODO confirm).
        rnn_out = rnn_outs[-1].squeeze()
        output_flat = output.view(-1, ntokens)

        # --- Extend the pointer history with this chunk -----------------
        # start_idx is the history length *before* this chunk is appended.
        if next_word_history is None:
            start_idx = 0
            next_word_history = torch.cat(
                [one_hot(t.data[0], ntokens) for t in targets])
        else:
            start_idx = len(next_word_history)
            next_word_history = torch.cat(
                [next_word_history,
                 torch.cat([one_hot(t.data[0], ntokens) for t in targets])])
        if pointer_history is None:
            pointer_history = Variable(rnn_out.data)
        else:
            pointer_history = torch.cat(
                [pointer_history, Variable(rnn_out.data)], dim=0)

        # --- Manual cross entropy over the (possibly mixed) distribution -
        # This replaces criterion(output_flat, targets): each position's
        # loss is -log of the probability assigned to its target.
        loss = 0
        softmax_output_flat = torch.nn.functional.softmax(output_flat)
        for idx, vocab_loss in enumerate(softmax_output_flat):
            position = start_idx + idx
            p = vocab_loss
            if position > window:
                window_start = position - window
                # Slices end at `position`, so the current step is excluded.
                valid_next_word = next_word_history[window_start:position]
                valid_pointer_history = pointer_history[window_start:position]
                logits = torch.mv(valid_pointer_history, rnn_out[idx])
                theta = args.theta
                ptr_attn = torch.nn.functional.softmax(theta * logits).view(-1, 1)
                weighted_words = ptr_attn.expand_as(valid_next_word) * valid_next_word
                ptr_dist = weighted_words.sum(0).squeeze()
                lambdah = args.lambdasm
                p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss
            target_loss = p[targets[idx].data]
            loss += (-torch.log(target_loss)).data[0]
        total_loss += loss / batch_size

        hidden = repackage_hidden(hidden)
        # Keep only the most recent `window` positions for the next chunk.
        next_word_history = next_word_history[-window:]
        pointer_history = pointer_history[-window:]
    return total_loss / len(data_source)


# Load the best saved model.