The following two code examples, extracted from open-source Python projects, illustrate how to use nltk.compat().
def decode():
    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        src_lang_vocab_path = PATH_TO_DATA_FILES + FLAGS.src_lang + "_mapping%d.txt" % FLAGS.src_lang_vocab_size
        dst_lang_vocab_path = PATH_TO_DATA_FILES + FLAGS.dst_lang + "_mapping%d.txt" % FLAGS.dst_lang_vocab_size
        src_lang_vocab, _ = data_utils.initialize_vocabulary(src_lang_vocab_path)
        _, rev_dst_lang_vocab = data_utils.initialize_vocabulary(dst_lang_vocab_path)

        # Decode from standard input.
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            # Get token-ids for the input sentence.
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), src_lang_vocab)
            # Which bucket does it belong to? (xrange implies Python 2,
            # or six.moves.xrange in the original project.)
            bucket_id = min([b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)])
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence.
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket_id, True)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            # Print out the target-language sentence corresponding to outputs.
            print(" ".join([tf.compat.as_str(rev_dst_lang_vocab[output]) for output in outputs]))
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
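The compat helpers do the byte/string conversions at both ends of the pipeline: input text is converted to bytes before tokenization, and the reverse vocabulary entries (presumably byte strings, since they are passed through tf.compat.as_str above) are converted back to text for printing. A minimal, self-contained sketch of that round-trip; the sample sentence and vocabulary here are made up for illustration:

import tensorflow as tf

# str -> bytes before tokenization, mirroring sentence_to_token_ids above.
sentence = "how are you\n"            # stands in for sys.stdin.readline()
print(tf.compat.as_bytes(sentence))   # b'how are you\n'

# bytes -> str when mapping output ids back to words, mirroring the print above.
rev_dst_lang_vocab = [b"_PAD", b"_GO", b"_EOS", b"_UNK", b"bonjour"]  # made-up vocabulary
print(tf.compat.as_str(rev_dst_lang_vocab[4]))  # 'bonjour'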
def test():
    """Test the translation model."""
    nltk.download('punkt')
    with tf.Session() as sess:
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        src_lang_vocab_path = PATH_TO_DATA_FILES + FLAGS.src_lang + "_mapping%d.txt" % FLAGS.src_lang_vocab_size
        dst_lang_vocab_path = PATH_TO_DATA_FILES + FLAGS.dst_lang + "_mapping%d.txt" % FLAGS.dst_lang_vocab_size
        src_lang_vocab, _ = data_utils.initialize_vocabulary(src_lang_vocab_path)
        _, rev_dst_lang_vocab = data_utils.initialize_vocabulary(dst_lang_vocab_path)

        weights = [0.25, 0.25, 0.25, 0.25]
        first_lang_file = open(generate_src_lang_sentences_file_name(FLAGS.src_lang))
        second_lang_file = open(generate_src_lang_sentences_file_name(FLAGS.dst_lang))
        total_bleu_value = 0.0
        computing_bleu_iterations = 0
        for first_lang_raw in first_lang_file:
            second_lang_gold_raw = second_lang_file.readline()
            # Get token-ids for the input sentence.
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(first_lang_raw), src_lang_vocab)
            # Which bucket does it belong to?
            try:
                bucket_id = min([b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)])
            except ValueError:
                # Sentence is longer than the largest bucket; skip it.
                continue
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence.
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket_id, True)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            # Reassemble the model's translation from the output ids.
            model_tran_res = " ".join([tf.compat.as_str(rev_dst_lang_vocab[output]) for output in outputs])
            second_lang_gold_tokens = word_tokenize(second_lang_gold_raw)
            model_tran_res_tokens = word_tokenize(model_tran_res)
            # NB: NLTK's signature is sentence_bleu(references, hypothesis, weights);
            # this call passes the model output as the reference and the gold
            # sentence as the hypothesis, the reverse of that convention.
            try:
                current_bleu_value = sentence_bleu([model_tran_res_tokens], second_lang_gold_tokens, weights)
                total_bleu_value += current_bleu_value
                computing_bleu_iterations += 1
            except ZeroDivisionError:
                pass
            # Guard against division by zero before any BLEU value is computed
            # (0 % 10 == 0, so the unguarded check would crash on early failures).
            if computing_bleu_iterations > 0 and computing_bleu_iterations % 10 == 0:
                print("BLEU value after %d iterations: %.2f"
                      % (computing_bleu_iterations, total_bleu_value / computing_bleu_iterations))
        # Assumes at least one BLEU value was computed above.
        final_bleu_value = total_bleu_value / computing_bleu_iterations
        print("Final BLEU value after %d iterations: %.2f" % (computing_bleu_iterations, final_bleu_value))
        return
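The BLEU bookkeeping in test() can be exercised in isolation. A minimal sketch with made-up sentences, using NLTK's documented argument order (gold reference first, model output second); the smoothing function is an addition here, since unsmoothed BLEU on short sentences often has zero 4-gram overlap:

import nltk
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

nltk.download('punkt', quiet=True)  # same resource the example downloads

gold = "the cat sits on the mat"           # made-up reference sentence
model_tran_res = "the cat sat on the mat"  # made-up model output

weights = [0.25, 0.25, 0.25, 0.25]  # uniform 1- to 4-gram weights, as in test()
bleu = sentence_bleu([word_tokenize(gold)], word_tokenize(model_tran_res), weights,
                     smoothing_function=SmoothingFunction().method1)
print("BLEU: %.2f" % bleu)

Note that BLEU is not symmetric: swapping reference and hypothesis, as test() above effectively does, generally yields a different score because the brevity penalty changes direction.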