我们从 Python 开源项目中提取了以下 5 个代码示例，用于说明如何使用 `torchvision.models.__dict__`（按名称查找模型构造函数）。
def load_defined_model(name, num_classes):
    """Build torchvision model ``name`` and load its model-zoo weights.

    Entries that ``diff_states`` reports as differing between the freshly
    initialised model and the downloaded state dict are overwritten with the
    model's own initial values before loading (presumably the layers whose
    shape depends on ``num_classes`` -- confirm against ``diff_states``).

    Args:
        name: Key into ``models.__dict__`` and ``model_urls``.
        num_classes: Forwarded to the model constructor.

    Returns:
        ``(model, diff)`` where ``diff`` is the list of ``(key, value)`` pairs
        yielded by ``diff_states`` that were copied from the initialised model.
    """
    if name == 'densenet169':
        # Densenets don't (yet) pass on num_classes, hack it in for 169.
        # Built directly instead of first constructing (and discarding) the
        # model returned by models.__dict__[name].
        model = torchvision.models.DenseNet(num_init_features=64, growth_rate=32,
                                            block_config=(6, 12, 32, 32),
                                            num_classes=num_classes)
    else:
        model = models.__dict__[name](num_classes=num_classes)

    pretrained_state = model_zoo.load_url(model_urls[name])

    # Diff: entries that disagree between the fresh model and the checkpoint.
    diff = list(diff_states(model.state_dict(), pretrained_state))
    print("Replacing the following state from initialized", name, ":",
          [d[0] for d in diff])
    # Distinct loop names so the `name` parameter is not rebound.
    for key, value in diff:
        pretrained_state[key] = value
    # Internal invariant: after patching, nothing should differ any more.
    assert not list(diff_states(model.state_dict(), pretrained_state))

    # Merge
    model.load_state_dict(pretrained_state)
    return model, diff
def get_cnn(self, arch, pretrained):
    """Create the torchvision CNN named ``arch`` and move it onto the GPU(s).

    Args:
        arch: Key into ``models.__dict__`` selecting the constructor.
        pretrained: When truthy the constructor is called with
            ``pretrained=True``; otherwise it is called with no arguments.

    Returns:
        For ``alexnet*``/``vgg*`` names: the model itself, with only its
        ``features`` submodule wrapped in ``nn.DataParallel``. For every other
        name: the whole model wrapped in ``nn.DataParallel``. In both cases
        ``.cuda()`` has been applied.
    """
    if pretrained:
        print("=> using pre-trained model '{}'".format(arch))
        ctor_kwargs = {'pretrained': True}
    else:
        print("=> creating model '{}'".format(arch))
        ctor_kwargs = {}
    net = models.__dict__[arch](**ctor_kwargs)

    # Every other architecture: replicate the whole network across GPUs.
    if not arch.startswith(('alexnet', 'vgg')):
        return nn.DataParallel(net).cuda()

    # AlexNet/VGG: only the `features` trunk is data-parallel; the remaining
    # submodules are just moved to the current CUDA device.
    net.features = nn.DataParallel(net.features)
    net.cuda()
    return net
def load_defined_model(path, num_classes, name):
    """Build torchvision model ``name`` and load weights from a local checkpoint.

    The checkpoint at ``path`` must hold its weights under ``'state_dict'``.
    ``module`` path components (presumably inserted by ``nn.DataParallel``
    wrappers, either around the whole model -- ``module.x`` -- or around a
    submodule -- ``features.module.0.weight``) are removed from each key
    before comparison and loading.

    Args:
        path: Checkpoint file passed to ``torch.load``.
        num_classes: Forwarded to the model constructor.
        name: Key into ``models.__dict__``.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        AssertionError: If ``diff_states`` reports any difference between the
            model and the cleaned checkpoint.
    """
    model = models.__dict__[name](num_classes=num_classes)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts; the
    # tensors are copied into the model by load_state_dict below anyway.
    pretrained_state = torch.load(path, map_location='cpu')

    new_pretrained_state = OrderedDict()
    for key, value in pretrained_state['state_dict'].items():
        # Drop whole "module" components of the module path only. A plain
        # str.replace("module.", "") also mangles names such as
        # "encoder.submodule.weight" into "encoder.subweight".
        parts = key.split(".")
        layer_name = ".".join([p for p in parts[:-1] if p != "module"] + parts[-1:])
        new_pretrained_state[layer_name] = value

    # Diff
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if diff:
        print("Mismatch in these layers :", name, ":", [d[0] for d in diff])
        # Raised explicitly (not `assert`) so the check survives `python -O`;
        # the AssertionError type is kept for existing callers.
        raise AssertionError(
            "checkpoint {!r} does not match model {!r} in {} entries".format(
                path, name, len(diff)))

    # Merge
    model.load_state_dict(new_pretrained_state)
    return model
# Load the model
def _nn_forward_hook(self, module, input, output, name=''):
    """Store a detached copy of a module's forward output in ``self.blobs``.

    Args:
        module: The module the hook fired on (unused here).
        input: The module's inputs (unused here).
        output: A tensor, or a list/tuple of tensors.
        name: Key under which the copy is stored in ``self.blobs``.
    """
    if isinstance(output, (list, tuple)):
        # Multi-output modules (list or tuple): one copy per element, stored as
        # a list. Previously tuples fell through and failed on `tuple.data`.
        self.blobs[name] = [o.detach().clone() for o in output]
    else:
        self.blobs[name] = output.detach().clone()

# NOTE(review): everything below is dead, commented-out code kept for reference.
# @staticmethod
# def _load_model_config(model_def):
#     if isinstance(model_def, torch.nn.Module):
#
#     elif '.' not in os.path.basename(model_def):
#         import torchvision.models as models
#         if model_def not in models.__dict__:
#             raise KeyError('Model {} does not exist in pytorchs model zoo.'.format(model_def))
#         print('Loading model {} from pytorch model zoo'.format(model_def))
#         return models.__dict__[model_def](pretrained=True)
#     else:
#         print('Loading model from {}'.format(model_def))
#         if model_def.endswith('.t7'):
#             return load_legacy_model(model_def)
#         else:
#             return torch.load(model_def)
#
#
#     if type(model_cfg) == str:
#         if not os.path.exists(model_cfg):
#             try:
#                 class_ = getattr(applications, model_cfg)
#                 return class_(weights=model_weights)
#             except AttributeError:
#                 available_mdls = [attr for attr in dir(applications) if callable(getattr(applications, attr))]
#                 raise ValueError('Could not load pretrained model with key {}. '
#                                  'Available models: {}'.format(model_cfg, ', '.join(available_mdls)))
#
#         with open(model_cfg, 'r') as fileh:
#             try:
#                 return model_from_json(fileh)
#             except ValueError:
#                 pass
#
#             try:
#                 return model_from_yaml(fileh)
#             except ValueError:
#                 pass
#
#         raise ValueError('Could not load model from configuration file {}. '
#                          'Make sure the path is correct and the file format is yaml or json.'.format(model_cfg))
#     elif type(model_cfg) == dict:
#         return Model.from_config(model_cfg)
#     elif type(model_cfg) == list:
#         return Sequential.from_config(model_cfg)
#
#     raise ValueError('Could not load model from configuration object of type {}.'.format(type(model_cfg)))