您好,登录后才能下订单哦!
密码登录
登录注册
点击 登录注册 即表示同意《亿速云用户服务条款》
本篇内容介绍了“PyTorch frozen怎么使用”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!
# ============================ step 2/5 模型 ============================ # 1/3 构建模型 resnet18_ft = models.resnet18() # 2/3 加载参数 # flag = 0 flag = 1 if flag: path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth") state_dict_load = torch.load(path_pretrained_model) resnet18_ft.load_state_dict(state_dict_load) # 3/3 替换fc层 num_ftrs = resnet18_ft.fc.in_features resnet18_ft.fc = nn.Linear(num_ftrs, classes) resnet18_ft.to(device)
# ============================ step 2/5 模型 ============================ # 1/3 构建模型 resnet18_ft = models.resnet18() # 2/3 加载参数 # flag = 0 flag = 1 if flag: path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth") state_dict_load = torch.load(path_pretrained_model) resnet18_ft.load_state_dict(state_dict_load) # 法1 : 冻结卷积层 flag_m1 = 0 # flag_m1 = 1 if flag_m1: for param in resnet18_ft.parameters(): param.requires_grad = False print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...])) # 3/3 替换fc层 num_ftrs = resnet18_ft.fc.in_features resnet18_ft.fc = nn.Linear(num_ftrs, classes) resnet18_ft.to(device)
# -*- coding: utf-8 -*- """ # @brief : 模型finetune方法 """ import os import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader import torchvision.transforms as transforms import torch.optim as optim from matplotlib import pyplot as plt from tools.my_dataset import AntsDataset from tools.common_tools2 import set_seed import torchvision.models as models import torchvision BASEDIR = os.path.dirname(os.path.abspath(__file__)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("use device :{}".format(device)) set_seed(1) # 设置随机种子 label_name = {"ants": 0, "bees": 1} # 参数设置 MAX_EPOCH = 25 BATCH_SIZE = 16 LR = 0.001 log_interval = 10 val_interval = 1 classes = 2 start_epoch = -1 lr_decay_step = 7 # ============================ step 1/5 数据 ============================ data_dir = os.path.join(BASEDIR, "..", "..", "data/hymenoptera_data") train_dir = os.path.join(data_dir, "train") valid_dir = os.path.join(data_dir, "val") norm_mean = [0.485, 0.456, 0.406] norm_std = [0.229, 0.224, 0.225] train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std), ]) valid_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std), ]) # 构建MyDataset实例 train_data = AntsDataset(data_dir=train_dir, transform=train_transform) valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform) # 构建DataLoder train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE) # ============================ step 2/5 模型 ============================ # 1/3 构建模型 resnet18_ft = models.resnet18() # 2/3 加载参数 # flag = 0 flag = 1 if flag: path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth") state_dict_load = torch.load(path_pretrained_model) resnet18_ft.load_state_dict(state_dict_load) # 法1 : 冻结卷积层 flag_m1 = 0 # flag_m1 = 1 if flag_m1: for param in resnet18_ft.parameters(): param.requires_grad = False print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...])) # 3/3 替换fc层 num_ftrs = resnet18_ft.fc.in_features resnet18_ft.fc = nn.Linear(num_ftrs, classes) resnet18_ft.to(device) # ============================ step 3/5 损失函数 ============================ criterion = nn.CrossEntropyLoss() # 选择损失函数 # ============================ step 4/5 优化器 ============================ # 法2 : conv 小学习率 # flag = 0 flag = 1 if flag: fc_params_id = list(map(id, resnet18_ft.fc.parameters())) # 返回的是parameters的 内存地址 base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters()) optimizer = optim.SGD([ {'params': base_params, 'lr': LR * 0}, # 0 {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9) else: optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9) # 选择优化器 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1) # 设置学习率下降策略 # ============================ step 5/5 训练 ============================ train_curve = list() valid_curve = list() for epoch in range(start_epoch + 1, MAX_EPOCH): loss_mean = 0. correct = 0. total = 0. resnet18_ft.train() for i, data in enumerate(train_loader): # forward inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) outputs = resnet18_ft(inputs) # backward optimizer.zero_grad() loss = criterion(outputs, labels) loss.backward() # update weights optimizer.step() # 统计分类情况 _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).squeeze().cpu().sum().numpy() # 打印训练信息 loss_mean += loss.item() train_curve.append(loss.item()) if (i + 1) % log_interval == 0: loss_mean = loss_mean / log_interval print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format( epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total)) loss_mean = 0. # if flag_m1: print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...])) scheduler.step() # 更新学习率 # validate the model if (epoch + 1) % val_interval == 0: correct_val = 0. total_val = 0. loss_val = 0. resnet18_ft.eval() with torch.no_grad(): for j, data in enumerate(valid_loader): inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) outputs = resnet18_ft(inputs) loss = criterion(outputs, labels) _, predicted = torch.max(outputs.data, 1) total_val += labels.size(0) correct_val += (predicted == labels).squeeze().cpu().sum().numpy() loss_val += loss.item() loss_val_mean = loss_val / len(valid_loader) valid_curve.append(loss_val_mean) print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format( epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val_mean, correct_val / total_val)) resnet18_ft.train() train_x = range(len(train_curve)) train_y = train_curve train_iters = len(train_loader) valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations valid_y = valid_curve plt.plot(train_x, train_y, label='Train') plt.plot(valid_x, valid_y, label='Valid') plt.legend(loc='upper right') plt.ylabel('loss value') plt.xlabel('Iteration') plt.show()
“PyTorch frozen怎么使用”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注亿速云网站,小编将为大家输出更多高质量的实用文章!
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。