我们从Python开源项目中,提取了以下7个代码示例,用于说明如何使用torchvision.models.resnet18()。
def __init__(self, embedding_size, num_classes, checkpoint=None):
    """Build a ResNet-18 backbone with embedding/classifier heads.

    Args:
        embedding_size: Output width of ``fc2`` and input width of the
            ``classifier`` layer.
        num_classes: Number of classifier outputs; also the number of rows
            of the ``centers`` tensor.
        checkpoint: Optional mapping providing a ``'state_dict'`` entry
            and, when the class count matches, a ``'centers'`` entry.
    """
    super(FaceModelCenter, self).__init__()

    backbone = resnet18()
    # NOTE(review): with avgpool removed, forward() is presumably
    # overridden elsewhere in this class -- confirm.
    backbone.avgpool = None
    # NOTE(review): 512*3*3 looks like the flattened size of the final
    # feature map for the expected input resolution -- confirm.
    backbone.fc1 = nn.Linear(512 * 3 * 3, 512)
    backbone.fc2 = nn.Linear(512, embedding_size)
    backbone.classifier = nn.Linear(embedding_size, num_classes)
    self.model = backbone

    self.centers = torch.zeros(num_classes, embedding_size).type(torch.FloatTensor)
    self.num_classes = num_classes

    self.apply(self.weights_init)

    if checkpoint is None:
        return

    # The first dimension of the last saved entry is compared with
    # num_classes to decide whether the whole checkpoint can be loaded.
    final_entry = list(checkpoint['state_dict'].values())[-1]
    if final_entry.size(0) == num_classes:
        self.load_state_dict(checkpoint['state_dict'])
        self.centers = checkpoint['centers']
        return

    # Class count differs: copy every saved entry except those whose name
    # contains "classifier".
    current_state = self.state_dict()
    for key, value in checkpoint['state_dict'].items():
        if "classifier" in key:
            continue
        if isinstance(value, Parameter):
            # backwards compatibility for serialized parameters
            value = value.data
        current_state[key].copy_(value)
def GetPretrainedModel(params, num_classes):
    """Return a pretrained torchvision ResNet with its ``fc`` head replaced.

    Args:
        params: Mapping whose ``'model'`` entry names the architecture; one
            of ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101`` or
            ``resnet152``.
        num_classes: Output size passed to the replacement ``SigmoidLinear``
            head.

    Returns:
        The model with ``model.fc`` swapped for
        ``SigmoidLinear(in_features, num_classes)``.

    Raises:
        ValueError: If ``params['model']`` is not a supported name.
    """
    supported = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')
    architecture = params['model']
    if architecture not in supported:
        raise ValueError('Unknown model type')

    # Each supported name matches the torchvision constructor of that name.
    model = getattr(models, architecture)(pretrained=True)
    model.fc = SigmoidLinear(model.fc.in_features, num_classes)
    return model
def __init__(self):
    """Wrap a pretrained ResNet-18 whose ``fc`` layer is a 17-output head."""
    super().__init__()
    backbone = models.resnet18(pretrained=True)
    self.pretrained_model = backbone
    # One Linear layer sized to the original fc layer's input width.
    head = nn.Sequential(nn.Linear(backbone.fc.in_features, 17))
    self.classifier = head
    # The same Sequential instance is also installed as the backbone's fc,
    # so it is reachable through both attributes.
    backbone.fc = head
def resnet18_weldon(num_classes, pretrained=True, kmax=1, kmin=None):
    """Build a ``ResNetWSL`` model on a ResNet-18 with ``WeldonPool2d`` pooling.

    Args:
        num_classes: Forwarded to ``ResNetWSL``.
        pretrained: Forwarded positionally to ``models.resnet18``.
        kmax: First argument of ``WeldonPool2d``.
        kmin: Second argument of ``WeldonPool2d``.

    Returns:
        The constructed ``ResNetWSL`` instance.
    """
    # Arguments evaluate left to right, so the backbone is still created
    # before the pooling layer.
    return ResNetWSL(
        models.resnet18(pretrained),
        num_classes,
        pooling=WeldonPool2d(kmax, kmin),
    )
def __init__(self, embedding_size, num_classes, pretrained=False):
    """Set up a ResNet-18 backbone with embedding and classifier layers.

    Args:
        embedding_size: Output width of the replaced ``fc`` layer and input
            width of the ``classifier`` layer.
        num_classes: Number of outputs of the ``classifier`` layer.
        pretrained: Passed positionally to ``resnet18``.
    """
    super(FaceModel, self).__init__()
    backbone = resnet18(pretrained)
    self.model = backbone
    self.embedding_size = embedding_size
    # NOTE(review): 512*3*3 presumably matches the flattened feature map
    # fed to fc by this class's forward() -- confirm.
    backbone.fc = nn.Linear(512 * 3 * 3, embedding_size)
    # The classifier is attached to the backbone, not to self.
    backbone.classifier = nn.Linear(embedding_size, num_classes)
def resnet18(num_classes=1000, pretrained='imagenet'):
    """Constructs a ResNet-18 model.

    Args:
        num_classes: Passed to ``load_pretrained``; only used when
            ``pretrained`` is not ``None``.
        pretrained: Key into ``pretrained_settings['resnet18']``, or
            ``None`` to skip loading pretrained settings.

    Returns:
        The result of ``modify_resnets`` applied to the model.
    """
    network = models.resnet18(pretrained=False)
    if pretrained is None:
        return modify_resnets(network)
    chosen_settings = pretrained_settings['resnet18'][pretrained]
    network = load_pretrained(network, num_classes, chosen_settings)
    return modify_resnets(network)
def _GetArguments():
    """Parse the command line and return the options as a dict.

    The parsed options are also printed as indented JSON.

    Returns:
        A dict mapping each option name (without the leading dashes) to its
        parsed value.
    """
    # Boolean switches: (flag, help text).
    switches = (
        ("--interactive", "Run the script in an interactive mode"),
        ("--verbose", "Visualize all logs"),
    )
    # Options that take a value: (flag, add_argument keyword arguments).
    valued_options = (
        ("--voc_devkit_dir",
         dict(default="data/VOCdevkit",
              help="Root directory of VOC development kit")),
        ("--voc_version",
         dict(default="VOC2007", help="Target VOC dataset version")),
        ("--model",
         dict(default="resnet18", help="Pretrained model")),
        ("--batch_size",
         dict(default=100, type=int, help="Batch size")),
        ("--num_data_loading_workers",
         dict(default=4, type=int, help="Number of data loading workers")),
        ("--num_epochs",
         dict(default=50, type=int, help="Number of epochs for training")),
        ("--initial_learning_rate",
         dict(default=0.1, type=float,
              help="Initial learning rate for SGD")),
        ("--learning_rate_decay_epoch",
         dict(default=50, type=int,
              help="How frequently the learning rate will be decayed")),
        ("--momentum",
         dict(default=0.9, type=float,
              help="Momentum for learning rate of SGD")),
        ("--seed",
         dict(default=111, type=int, help="Random seed")),
        ("--save_dir",
         dict(default="temp_result",
              help="Directory for saving results of train / test")),
    )

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", default=False,
                            help=description)
    for flag, settings in valued_options:
        parser.add_argument(flag, **settings)

    params = vars(parser.parse_args())
    print(json.dumps(params, indent=2))
    return params