PyTorch Examples
ResNet18
Write code for model training.
import argparse
import os
import shutil
import time
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.models.resnet import resnet18
from pytorch_nndct import Pruner
from pytorch_nndct import InputSpec
def _str2bool(value):
    """Interpret a command-line token such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, and
    every non-empty string (including ``'False'``) is truthy, so the text has
    to be parsed explicitly.

    Args:
        value: The raw option value (a string, or already a bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value, got {!r}.'.format(value))


# Command-line options of the pruning script.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--data_dir',
    default='/scratch/workspace/dataset/imagenet/pytorch',
    help='Data set directory.')
parser.add_argument(
    '--pretrained',
    default='/scratch/workspace/wangyu/nndct_test_data/models/resnet18.pth',
    help='Trained model file path.')
parser.add_argument(
    '--ratio',
    default=0.1,
    type=float,
    help='Desired pruning ratio. The larger this value, the smaller '
    'the model after pruning.')
parser.add_argument(
    '--ana',
    default=False,
    # Not ``type=bool``: that would turn ``--ana False`` into True.
    type=_str2bool,
    help='Whether to perform model analysis.')
# parse_known_args tolerates extra, unrelated options on the command line.
args, _ = parser.parse_known_args()
class AverageMeter:
    """Tracks the most recent value of a metric and its weighted running mean.

    ``name`` labels the metric and ``fmt`` is the format spec (for example
    ``':6.2f'``) applied to both the latest value and the average in
    ``__str__``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:6.2f} ({avg:6.2f})' and fills it from the
        # instance attributes.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**self.__dict__)
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is truthy, the file that was just written is also copied
    to ``model_best.pth.tar`` (a relative path, so it lands in the current
    working directory).
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
class ProgressMeter:
    """Prints one tab-separated progress line assembled from a list of meters.

    ``num_batches`` is the total shown in the ``[current/total]`` counter,
    ``meters`` are objects rendered with ``str()``, and ``prefix`` is placed
    in front of the counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and every meter for ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header] + [str(m) for m in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        # The running index is padded to the number of digits in the total so
        # successive lines stay aligned.
        width = len(str(num_batches // 1))
        counter = '{:' + str(width) + 'd}'
        return '[{0}/{1}]'.format(counter, counter.format(num_batches))
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each requested ``k``.

    Args:
        output: Score tensor indexed as ``(sample, class)``; the top-k is
            taken along dimension 1.
        target: Tensor of ground-truth class indices, one per sample.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``,
        holding the percentage of samples whose target is among the ``k``
        highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        # Transpose to (maxk, batch) so row i holds every sample's i-th guess.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # ``correct`` derives from a transposed tensor and is generally
            # not contiguous, so ``view(-1)`` raises a RuntimeError for k > 1;
            # ``reshape`` copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def adjust_learning_rate(optimizer, epoch, lr):
    """Apply a step schedule: the base ``lr`` shrinks 10x every 2 epochs.

    The decayed rate is written into the ``'lr'`` entry of every parameter
    group of ``optimizer``.
    """
    decayed = lr * (0.1 ** (epoch // 2))
    for group in optimizer.param_groups:
        group['lr'] = decayed
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch on the GPU and print progress every 10 batches.

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to optimise; it is moved to the current CUDA device.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        epoch: Epoch index, used only in the progress prefix.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader), [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Move the model once up front; doing it per batch is redundant work.
    model = model.cuda()
    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # non_blocking matches evaluate(); it only has an effect when the
        # loader delivers pinned-memory batches.
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            progress.display(i)
def evaluate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` on the GPU without gradients.

    Progress is printed every 50 batches and a summary line at the end.

    Args:
        val_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; it is moved to the current CUDA device.
        criterion: Loss callable invoked as ``criterion(output, target)``.

    Returns:
        A ``(top1_avg, top5_avg)`` pair with the running averages of the
        top-1 and top-5 accuracy (in percent) over all batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

    # Move the model once up front; doing it per batch is redundant work.
    model = model.cuda()
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 50 == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))

    return top1.avg, top5.avg
Define the evaluation function used for model analysis. It returns the model's top-5 accuracy on the validation set.
def ana_eval_fn(model, val_loader, loss_fn):
    """Score ``model`` for analysis: its top-5 accuracy from ``evaluate``."""
    _, top5_avg = evaluate(val_loader, model, loss_fn)
    return top5_avg
Create a resnet18 model and add pruning APIs to perform pruning.
if __name__ == '__main__':
    # Build the network on the CPU and restore the pretrained weights.
    # map_location keeps the loaded tensors on the CPU as well, so a
    # checkpoint saved from CUDA tensors loads without first being
    # materialised on the GPU device it was saved from.
    model = resnet18().cpu()
    model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))

    batch_size = 128
    workers = 4

    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, 'validation')

    # Per-channel mean/std normalisation shared by both datasets.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True)

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True)

    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Dummy input handed to the pruner together with the model.
    inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
    pruner = Pruner(model, inputs)
    if args.ana:
        # Optional model analysis, scored by ana_eval_fn (top-5 accuracy).
        pruner.ana(ana_eval_fn, args=(val_loader, criterion), gpus=[0, 1, 2, 3])

    model = pruner.prune(ratio=args.ratio)
    pruner.summary(model)

    # Fine-tune the pruned model.
    lr = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-4)

    best_acc5 = 0
    epochs = 1
    for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch, lr)
        train(train_loader, model, criterion, optimizer, epoch)

        acc1, acc5 = evaluate(val_loader, model, criterion)

        # remember best acc@5 and save checkpoints
        is_best = acc5 > best_acc5
        best_acc5 = max(acc5, best_acc5)

        if is_best:
            torch.save(model.pruned_state_dict(), 'resnet18_sparse.pth')
            torch.save(model.state_dict(), 'resnet18_dense.pth')
Download pretrained ResNet18 model:
wget https://download.pytorch.org/models/resnet18-5c106cde.pth -O resnet18.pth
Prepare ImageNet dataset, see http://image-net.org/download-images.
Run the first round of pruning with model analysis:
$ python -u resnet18_pruning.py --data_dir imagenet_dir --pretrained resnet18.pth --ratio 0.1 --ana True
Starting from the second round, pruning no longer requires model analysis; you only need to increase the pruning ratio and use the sparse checkpoint saved in the previous round as the pretrained weights.
$ python -u resnet18_pruning.py --data_dir imagenet_dir --pretrained resnet18_sparse.pth --ratio 0.2