我正在尝试创建一个自定义数据生成器,但不知道如何yield在__getitem__方法内部将函数与无限循环结合在一起。
yield
__getitem__
编辑 :答案后,我意识到我正在使用的代码是Sequence不需要yield声明的。
Sequence
目前,我正在返回多张图片,并附上一条return声明:
return
class DataGenerator(tensorflow.keras.utils.Sequence): def __init__(self, files, labels, batch_size=32, shuffle=True, random_state=42): 'Initialization' self.files = files self.labels = labels self.batch_size = batch_size self.shuffle = shuffle self.random_state = random_state self.on_epoch_end() def __len__(self): return int(np.floor(len(self.files) / self.batch_size)) def __getitem__(self, index): # Generate indexes of the batch indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size] files_batch = [self.files[k] for k in indexes] y = [self.labels[k] for k in indexes] # Generate data x = self.__data_generation(files_batch) return x, y def on_epoch_end(self): 'Updates indexes after each epoch' self.indexes = np.arange(len(self.files)) if self.shuffle == True: np.random.seed(self.random_state) np.random.shuffle(self.indexes) def __data_generation(self, files): imgs = [] for img_file in files: img = cv2.imread(img_file, -1) ############### # Augment image ############### imgs.append(img) return imgs
在本文中,我看到了yield它在无限循环中使用的情况。我不太了解这种语法。循环如何逃逸?
您正在使用Sequence API,该API与普通生成器的工作原理有所不同。在生成器函数中,您将使用yield关键字在循环内执行迭代while True:,因此,每次Keras调用生成器时,它都会获取一批数据,并自动环绕数据的末尾。
while True:
但是在序列中,函数有一个index参数__getitem__,因此不需要迭代或不需要迭代yield,这由Keras为您执行。这样可以使序列可以使用多重处理并行运行,而这对于旧的生成器函数是不可能的。
index
因此,您以正确的方式行事,无需任何更改。