我们从 Python 开源项目中提取了以下 5 个代码示例，用于说明如何使用 torch.backends.cudnn.RNNDescriptor()。
def init_rnn_descriptor(fn, handle):
    """Create a cuDNN RNN descriptor for ``fn``, reusing a cached dropout descriptor.

    The dropout descriptor is cached in ``fn.dropout_state``. The key is
    ``'desc_'`` followed by the current CUDA device index. The value is wrapped
    in ``Unserializable``. A new ``cudnn.DropoutDescriptor`` is built when the
    key is absent or the cached wrapper's ``.get()`` returns ``None``.
    ``set_dropout`` is then called on the cached descriptor with the current
    probability and ``fn.dropout_seed`` every time this function runs.

    Args:
        fn: Object carrying the RNN configuration. This function reads
            ``dropout``, ``train``, ``dropout_seed``, ``dropout_state``,
            ``hidden_size``, ``num_layers``, ``input_mode``,
            ``bidirectional``, ``mode`` and ``datatype``.
        handle: cuDNN handle forwarded to both descriptor constructors.

    Returns:
        A ``cudnn.RNNDescriptor`` built from ``handle``, the cached dropout
        descriptor and ``fn``'s configuration.
    """
    cache_key = 'desc_' + str(torch.cuda.current_device())

    # Outside training mode (falsy ``fn.train``) a dropout probability of 0 is used.
    if fn.train:
        probability = fn.dropout
    else:
        probability = 0

    state = fn.dropout_state
    cache_is_empty = cache_key not in state or state[cache_key].get() is None
    if cache_is_empty:
        fresh_desc = cudnn.DropoutDescriptor(handle, probability, fn.dropout_seed)
        state[cache_key] = Unserializable(fresh_desc)

    dropout_desc = state[cache_key].get()
    # Runs on every call, so a reused descriptor receives the current
    # probability and seed.
    dropout_desc.set_dropout(probability, fn.dropout_seed)

    return cudnn.RNNDescriptor(
        handle,
        fn.hidden_size,
        fn.num_layers,
        dropout_desc,
        fn.input_mode,
        fn.bidirectional,
        fn.mode,
        fn.datatype,
    )
def init_rnn_descriptor(fn):
    """Build a ``cudnn.RNNDescriptor`` and configure it from ``fn``.

    The descriptor is default-constructed first. It is then configured with a
    single ``set(...)`` call, passing ``fn``'s configuration attributes in order.

    Args:
        fn: Object exposing ``hidden_size``, ``num_layers``, ``dropout_desc``,
            ``input_mode``, ``bidirectional``, ``mode`` and ``datatype``.
            NOTE(review): ``fn.dropout_desc`` is presumably a dropout
            descriptor prepared elsewhere. Confirm against the caller.

    Returns:
        The configured ``cudnn.RNNDescriptor``.
    """
    descriptor = cudnn.RNNDescriptor()
    settings = (
        fn.hidden_size,
        fn.num_layers,
        fn.dropout_desc,
        fn.input_mode,
        fn.bidirectional,
        fn.mode,
        fn.datatype,
    )
    descriptor.set(*settings)
    return descriptor