我正在训练一个与本示例非常相似的CNN进行图像分割。图片尺寸为1500x1500x1,标签尺寸相同。
定义CNN结构之后,并按照以下代码示例所述启动会话:(conv_net_test.py)
conv_net_test.py
with tf.Session() as sess: sess.run(init) summ = tf.train.SummaryWriter('/tmp/logdir/', sess.graph_def) step = 1 print ("import data, read from read_data_sets()...") #Data defined by me, returns a DataSet object with testing and training images and labels for segmentation problem. data = import_data_test.read_data_sets('Dataset') # Keep training until reach max iterations while step * batch_size < training_iters: batch_x, batch_y = data.train.next_batch(batch_size) print ("running backprop for step %d" % step) batch_x = batch_x.reshape(batch_size, n_input, n_input, n_channels) batch_y = batch_y.reshape(batch_size, n_input, n_input, n_channels) batch_y = np.int64(batch_y) sess.run(optimizer, feed_dict={x: batch_x, y: batch_y, keep_prob: dropout}) if step % display_step == 0: # Calculate batch loss and accuracy #pdb.set_trace() loss, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.}) step += 1 print "Optimization Finished"
我遇到了以下TypeError(下面的stacktrace):
conv_net_test.py in <module>() 178 #pdb.set_trace() --> 179 loss, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.}) 180 step += 1 181 print "Optimization Finished!" tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata) 370 try: 371 result = self._run(None, fetches, feed_dict, options_ptr, --> 372 run_metadata_ptr) 373 if run_metadata: 374 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata) 582 583 # Validate and process fetches. --> 584 processed_fetches = self._process_fetches(fetches) 585 unique_fetches = processed_fetches[0] 586 target_list = processed_fetches[1] tensorflow/python/client/session.pyc in _process_fetches(self, fetches) 538 raise TypeError('Fetch argument %r of %r has invalid type %r, ' 539 'must be a string or Tensor. (%s)' --> 540 % (subfetch, fetch, type(subfetch), str(e))) TypeError: Fetch argument 1.4415792e+2 of 1.4415792e+2 has invalid type <type 'numpy.float32'>, must be a string or Tensor. (Can not convert a float32 into a Tensor or Operation.)
在这一点上,我很沮丧。也许这是转换类型的简单情况,但是我不确定如何/在哪里。另外,为什么损失必须是字符串?(假设此错误一经修复,也会为准确性弹出相同的错误)。
任何帮助表示赞赏!
在使用的地方loss = sess.run(loss),您要在python中重新定义变量loss。
loss = sess.run(loss)
loss
第一次运行会很好。第二次,您将尝试执行以下操作:
sess.run(1.4415792e+2)
因为loss现在是浮点数。
您应该使用其他名称,例如:
loss_val, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})