我们从Python开源项目中，提取了以下6个代码示例，用于说明如何使用 caffe.io 模块。
def __init__(self, net_proto, net_weights, device_id, input_size=None):
    """Load a Caffe network on a GPU and set up an input transformer for it.

    Args:
        net_proto: network definition passed to ``caffe.Net``.
        net_weights: trained weights passed to ``caffe.Net``.
        device_id: GPU index passed to ``caffe.set_device``.
        input_size: optional trailing dimensions that replace everything
            after the first two entries of the ``data`` blob shape when
            building the transformer. Any sequence (tuple or list) works.

    Attributes set:
        self._net: the network, created in the ``caffe.TEST`` phase.
        self._transformer: a ``caffe.io.Transformer`` for input ``'data'``.
        self._sample_shape: the ``data`` blob shape as loaded from the
            network (not adjusted by ``input_size``).
    """
    caffe.set_mode_gpu()
    caffe.set_device(device_id)
    self._net = caffe.Net(net_proto, net_weights, caffe.TEST)

    net_shape = self._net.blobs['data'].data.shape
    input_shape = net_shape
    if input_size is not None:
        # Blob shapes are tuples; tuple() lets callers pass a list too
        # (tuple + list would raise TypeError).
        input_shape = tuple(net_shape[:2]) + tuple(input_size)

    transformer = caffe.io.Transformer({'data': input_shape})
    if net_shape[1] == 3:
        # Move the image channel axis (last) to the outermost dimension.
        transformer.set_transpose('data', (2, 0, 1))
        # Subtract a fixed per-channel mean; these values look like the
        # common BGR ImageNet means -- confirm against the training setup.
        transformer.set_mean('data', np.array([104, 117, 123]))
    # Non-3-channel (non-RGB) input: no transformer steps are configured.

    self._transformer = transformer
    self._sample_shape = net_shape
def set_mean(self):
    """Populate ``self._mean`` from ``self._mean_file``.

    ``self._mean_file`` may be:
      * empty/``None`` -> ``self._mean`` becomes ``None``;
      * a path string -> loaded with ``np.load``; if numpy cannot interpret
        the file, it is parsed as a Caffe ``BlobProto`` (binaryproto) and
        the first item of the converted array is used;
      * anything else (list, tuple, ndarray, ...) -> ``np.array`` of it.

    Raises:
        Whatever ``open``/``ParseFromString`` raise if the path is neither
        a numpy file nor a readable BlobProto.
    """
    src = self._mean_file
    # ndarray truthiness is ambiguous for more than one element, so test
    # its size instead of calling bool() on it.
    if isinstance(src, np.ndarray):
        is_empty = src.size == 0
    else:
        is_empty = not src
    if is_empty:
        self._mean = None
        return

    if not isinstance(src, str):
        self._mean = np.array(src)
        return

    # Read the image mean from file.
    try:
        self._mean = np.load(src)
    except (IOError, ValueError):
        # Older numpy raised IOError for content it could not interpret;
        # newer numpy (allow_pickle=False by default) raises ValueError,
        # e.g. for a .binaryproto file. Both fall back to BlobProto parsing.
        blob = caffe_pb2.BlobProto()
        with open(src, 'rb') as f:
            blob.ParseFromString(f.read())
        self._mean = np.array(caffe.io.blobproto_to_array(blob))[0]
def getFeats(ims, net, feat_layer, image_size=227):
    """Run a batch of images through ``net`` and return one blob's output.

    Args:
        ims: sequence of images, each passed to ``caffe.io.load_image``.
        net: a Caffe net whose input blob is named ``'data'``.
        feat_layer: name of the blob whose data is returned.
        image_size: side length the ``'data'`` blob is reshaped to
            (default 227, the previously hard-coded value).

    Returns:
        A copy of ``net.blobs[feat_layer].data`` after the forward pass.
    """
    batch_shape = (len(ims), 3, image_size, image_size)
    net.blobs['data'].reshape(*batch_shape)

    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    # IM_MEAN is a module-level constant defined elsewhere in the file.
    transformer.set_mean('data', IM_MEAN)
    # Move the channel axis (last) to the front.
    transformer.set_transpose('data', (2, 0, 1))
    # Reverse channel order (e.g. RGB -> BGR).
    transformer.set_channel_swap('data', (2, 1, 0))
    # Presumably rescales load_image's [0, 1] floats to [0, 255] -- confirm.
    transformer.set_raw_scale('data', 255.0)

    caffe_input = np.empty(batch_shape)
    for ix, im in enumerate(ims):
        caffe_input[ix] = transformer.preprocess('data', caffe.io.load_image(im))

    net.blobs['data'].data[...] = caffe_input
    net.forward()
    # Copy so later forward passes do not overwrite the returned features.
    return net.blobs[feat_layer].data.copy()
def load_all(self):
    """Load every record from the data LMDB (and the optional label LMDB).

    Reads ``self._source_fn`` as an LMDB of serialized Caffe ``Datum``
    messages and appends each datum's raw ``data`` bytes (still to be
    decoded, e.g. a JPEG string) to ``self._data``.

    Labels are appended to ``self._labels`` as strings:
      * without ``self._label_fn``: ``str(datum.label)``;
      * with ``self._label_fn``: the i-th label record is paired with the
        i-th data record, converted with ``caffe.io.datum_to_array`` and
        flattened into ":"-joined integers (e.g. "3:0:7").

    Returns:
        ``(self._data, self._labels)``; ``self._labels`` is converted to a
        numpy array first.

    Raises:
        Exception: "Error in Parsing input file" if opening or parsing either
        database fails, or if the label database has fewer records than the
        data database. The original error is kept as the exception context.
    """
    start = time.time()
    print("Start Loading Data from LMDB {}".format(self._source_fn))
    data_env = None
    label_env = None
    try:
        # readonly: these databases are only read, and a wrong path now
        # raises instead of silently creating an empty database.
        data_env = lmdb.open(self._source_fn, readonly=True)
        label_values = None
        if self._label_fn:
            label_env = lmdb.open(self._label_fn, readonly=True)
            label_values = label_env.begin().cursor().iternext(keys=False)

        # iternext() on a fresh cursor starts at the first record, so data
        # record i is paired with label record i. Calling first() and then
        # next() before each read would pair it with label record i + 1.
        for value_str in data_env.begin().cursor().iternext(keys=False):
            datum = caffe_pb2.Datum()
            datum.ParseFromString(value_str)
            self._data.append(datum.data)

            if label_values is None:
                label = str(datum.label)
            else:
                label_str = next(label_values, None)
                if label_str is None:
                    raise ValueError(
                        "label database has fewer records than data database")
                label_datum = caffe_pb2.Datum()
                label_datum.ParseFromString(label_str)
                label_arr = caffe.io.datum_to_array(label_datum)
                # datum_to_array returns a (channels, height, width) array;
                # ravel() so each joined element is a scalar, not a sub-array.
                label = ":".join(str(x) for x in label_arr.astype(int).ravel())
            self._labels.append(label)
    except Exception:
        raise Exception("Error in Parsing input file")
    finally:
        # Close the databases whether or not loading succeeded.
        for env in (data_env, label_env):
            if env is not None:
                env.close()
    end = time.time()
    self._labels = np.array(self._labels)
    print("Loading {} samples Done: Time cost {} seconds".format(
        len(self._data), end - start))
    return self._data, self._labels