我们从 Python 开源项目中提取了以下 3 个代码示例，用于说明如何使用 caffe.proto.caffe_pb2.LabelMap()。
def __init__(self):
    """Load the label map and the Caffe network, and configure input preprocessing.

    Reads the module-level settings CAFFE_ROOT, LABEL_MAP, PROTO_TXT,
    CAFFE_MODEL and IMAGE_SIZE (defined outside this snippet).

    Sets:
        self._labelmap: caffe_pb2.LabelMap parsed from the label map prototxt
            (the original comment describes it as the MS COCO labels).
        self._net: caffe.Net loaded in TEST phase.
        self._transformer: caffe.io.Transformer configured for the 'data' blob.
    """
    # Parse the text-format label map; `with` guarantees the handle is closed.
    labelmap_file = os.path.join(CAFFE_ROOT, LABEL_MAP)
    self._labelmap = caffe_pb2.LabelMap()
    with open(labelmap_file, 'r') as f:
        text_format.Merge(str(f.read()), self._labelmap)

    model_def = os.path.join(CAFFE_ROOT, PROTO_TXT)
    model_weights = os.path.join(CAFFE_ROOT, CAFFE_MODEL)
    self._net = caffe.Net(model_def, model_weights, caffe.TEST)

    # Set the net to a batch size of 1 at IMAGE_SIZE x IMAGE_SIZE *before*
    # building the transformer: the Transformer records the shape it is given
    # and resizes every input to that spatial size, so it must see the final
    # blob shape rather than the one from the prototxt.
    image_resize = IMAGE_SIZE
    self._net.blobs['data'].reshape(1, 3, image_resize, image_resize)

    self._transformer = caffe.io.Transformer(
        {'data': self._net.blobs['data'].data.shape})
    # Move the last axis to the front (presumably H x W x C -> C x H x W).
    self._transformer.set_transpose('data', (2, 0, 1))
    # Per-channel mean to subtract; these look like the usual BGR ImageNet
    # means -- TODO confirm against how the model was trained.
    self._transformer.set_mean('data', np.array([104, 117, 123]))
    # Scale by 255 (presumably from [0, 1] float images to the [0, 255] range).
    self._transformer.set_raw_scale('data', 255)
    # Reverse the channel order (presumably RGB -> BGR).
    self._transformer.set_channel_swap('data', (2, 1, 0))
def labelmap(labelmap_file, label_info):
    """Write a Caffe LabelMap prototxt built from label descriptions.

    Args:
        labelmap_file: path of the text file to create or overwrite.
        label_info: iterable of dict-like records, each providing 'name',
            'label' and 'display_name'; one LabelMapItem is emitted per
            record, in iteration order.
    """
    labelmap = caffe_pb2.LabelMap()
    for info in label_info:
        # Fill the new repeated-field entry in place; no need to build a
        # separate LabelMapItem and copy it in with MergeFrom.
        item = labelmap.item.add()
        item.name = info['name']
        item.label = info['label']
        item.display_name = info['display_name']
    with open(labelmap_file, 'w') as f:
        # str() of a protobuf message is its text-format representation.
        f.write(str(labelmap))
def detection(img, net, transformer, labels_file):
    """Detect objects in one image, draw confident detections, and display it.

    Loads the image, runs a single forward pass (printing its duration), draws
    a box and class name for every detection whose score is >= 0.5, and shows
    the annotated image with matplotlib.

    Args:
        img: path of the image, loaded with caffe.io.load_image.
        net: caffe.Net whose 'data' blob receives the preprocessed image and
            whose 'detection_out' blob is read after the forward pass.
        transformer: caffe.io.Transformer configured for the 'data' input.
        labels_file: path of a LabelMap prototxt used to name the classes.
    """
    import timeit  # time.clock() no longer exists on Python >= 3.8

    im = caffe.io.load_image(img)
    net.blobs['data'].data[...] = transformer.preprocess('data', im)

    # Time the forward pass with the best available clock (Py2 and Py3).
    start = timeit.default_timer()
    net.forward()
    end = timeit.default_timer()
    print('detection time: %f s' % (end - start))

    # Load the label map and index display names by each item's `label`
    # value, so maps whose ids are not stored in 0..N-1 order still resolve.
    labelmap = caffe_pb2.LabelMap()
    with open(labels_file, 'r') as f:
        text_format.Merge(str(f.read()), labelmap)
    names = {item.label: item.display_name for item in labelmap.item}

    # Rows are read as [_, label, score, xmin, ymin, xmax, ymax] with box
    # coordinates normalized to the image size (presumably the SSD
    # DetectionOutput layout -- confirm).
    detections = net.blobs['detection_out'].data[0][0]
    confidence_threshold = 0.5
    height, width = im.shape[0], im.shape[1]
    # The image is presumably float in [0, 1], so the 8-bit colour is
    # rescaled; box and text share it so the text does not saturate.
    color = (55 / 255.0, 255 / 255.0, 155 / 255.0)
    for det in detections:
        if det[2] >= confidence_threshold:
            xmin = int(det[3] * width)
            ymin = int(det[4] * height)
            xmax = int(det[5] * width)
            ymax = int(det[6] * height)
            # Draw the bounding box.
            cv2.rectangle(im, (xmin, ymin), (xmax, ymax), color, 2)
            # Caption it with the class name, falling back to the numeric id
            # when the label map has no entry for it.
            label_id = int(det[1])
            class_name = names.get(label_id, str(label_id))
            # cv2.FONT_HERSHEY_SIMPLEX exists in OpenCV 2.4+; the legacy
            # cv2.cv namespace is gone from OpenCV 3 onwards.
            cv2.putText(im, class_name, (xmin, ymax),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

    # Show the annotated image; no colormap is passed because matplotlib
    # ignores cmap for RGB(A) images.
    plt.imshow(im)
    plt.show()