博客
关于我
知识蒸馏DEiT算法实战:使用RegNet蒸馏DEiT模型
阅读量:458 次
发布时间:2019-03-06

本文共 12942 字,大约阅读时间需要 43 分钟。

文章目录

摘要

本文介绍如何借助外部教师网络对 DeiT 模型进行知识蒸馏。文中先讨论两种蒸馏方式,再详细讲解第二种方式的实现,即使用卷积神经网络(RegNet)作为教师模型来蒸馏 DeiT 模型。


模型和损失

model.py代码

# Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal
# 注册模型
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """Build a ``VisionTransformer`` with the DeiT-tiny hyper-parameters.

    The network uses patch_size=16, embed_dim=192, depth=12 and num_heads=3.

    Args:
        pretrained: When true, download the published checkpoint and load
            its ``"model"`` state dict into the network.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer`` instance.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Weights are mapped to CPU on load; moving the model to another
        # device is left to the caller.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model

losses.py代码

# Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
"""Implements the knowledge distillation loss"""
import torch
from torch.nn import functional as F
class DistillationLoss(torch.nn.Module):
    """Wrap a standard criterion and add a knowledge-distillation term.

    The teacher model's prediction on the same inputs is used as additional
    supervision for the student's distillation output.
    """

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        """Store the criterion, the teacher and the distillation settings.

        Args:
            base_criterion: Loss applied to the student's main output and
                the ground-truth labels.
            teacher_model: Model queried, without gradients, for the teacher
                predictions.
            distillation_type: One of ``'none'``, ``'soft'`` or ``'hard'``.
            alpha: Weight of the distillation term; the base loss is
                weighted by ``1 - alpha``.
            tau: Temperature that divides both logits in ``'soft'`` mode.
        """
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard'], (
            "distillation_type must be 'none', 'soft' or 'hard'"
        )
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """Compute the combined base + distillation loss.

        Args:
            inputs: The original inputs that are fed to the teacher model.
            outputs: The outputs of the model being trained. Either a Tensor,
                or a Tuple[Tensor, Tensor] with the original output in the
                first position and the distillation prediction in the second.
            labels: The labels for the base criterion.

        Returns:
            The base loss alone when distillation is ``'none'``; otherwise
            ``base_loss * (1 - alpha) + distillation_loss * alpha``.

        Raises:
            ValueError: If distillation is enabled but ``outputs`` is a
                single Tensor, or if the distillation type is unknown.
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")

        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # Both arguments are log-probabilities because log_target=True.
            # The summed KL is rescaled by T^2 and divided by the number of
            # elements in the student's distillation output.
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum', log_target=True
            ) * (T * T) / outputs_kd.numel()
        elif self.distillation_type == 'hard':
            # The teacher's arg-max class is used as a hard target.
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
        else:
            # Only reachable when the constructor assert was stripped (-O);
            # fail clearly instead of with an unbound-variable error.
            raise ValueError(
                "Unsupported distillation_type: {!r}".format(self.distillation_type)
            )

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss

训练Teacher模型

teacher_train.py代码

# 导入必要的库
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from timm.models import regnetx_160
import json
import os
# 定义训练过程
def seed_everything(seed=42):
    """Seed the random number generators used by the script.

    Args:
        seed: Integer applied to the ``PYTHONHASHSEED`` environment variable,
            Python's ``random`` module and the PyTorch CPU/CUDA generators.
    """
    # Imported locally because the script's import header does not include it.
    import random

    # NOTE: changing PYTHONHASHSEED at runtime only affects child processes;
    # the running interpreter has already fixed its hash seed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels between runs.
    torch.backends.cudnn.benchmark = False
# 训练函数
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and print progress.

    The loss function is read from the module-level name ``criterion``,
    which must be defined by the surrounding script before this is called.

    Args:
        model: Network to optimise; switched to training mode here.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` and ``len()``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the printed messages.
    """
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        # Tensors are moved directly; the deprecated Variable wrapper is a
        # no-op on current PyTorch.
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), total_num,
                100. * (batch_idx + 1) / len(train_loader), print_loss))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate the model, save it when accuracy improves, and return accuracy.

    Reads the module-level names ``criterion`` (loss function), ``file_dir``
    (output directory) and ``Best_ACC`` (best accuracy so far); all must be
    defined by the surrounding script. ``Best_ACC`` is updated in place.

    Args:
        model: Network to evaluate; switched to eval mode here.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` and ``len()``.

    Returns:
        Fraction of correctly classified samples over the whole dataset.
    """
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    # The decorator already disables gradients; no inner no_grad is needed.
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        _, pred = torch.max(out, 1)
        # Accumulate a plain int so the count is valid even with no batches.
        correct += torch.sum(pred == target).item()
        test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    if acc > Best_ACC:
        # The whole model object is saved, not just its state dict.
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return acc

全局参数

if __name__ == '__main__':
    # Create the directory that will hold the saved model.
    file_dir = 'TeacherModel'
    if os.path.exists(file_dir):
        print('true')
    # exist_ok=True covers both the existing and the missing case, so there
    # is no separate branch that could fail if the directory appears in between.
    os.makedirs(file_dir, exist_ok=True)
    # Global settings.
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)

学生网络

student_train.py代码

# 导入必要的库
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os
# 定义训练过程
def seed_everything(seed=42):
    """Seed the random number generators used by the script.

    Args:
        seed: Integer applied to the ``PYTHONHASHSEED`` environment variable,
            Python's ``random`` module and the PyTorch CPU/CUDA generators.
    """
    # Imported locally because the script's import header does not include it.
    import random

    # NOTE: changing PYTHONHASHSEED at runtime only affects child processes;
    # the running interpreter has already fixed its hash seed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels between runs.
    torch.backends.cudnn.benchmark = False
# 定义训练过程
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch of the student model and print progress.

    The loss function is read from the module-level name ``criterion``,
    which must be defined by the surrounding script before this is called.

    Args:
        model: Network to optimise; switched to training mode here. Its
            output is indexed with ``[0]`` before the loss is computed.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` and ``len()``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the printed messages.
    """
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        # This script never imports Variable, so the tensors are moved
        # directly instead of being wrapped first.
        data, target = data.to(device), target.to(device)
        # NOTE(review): presumably the model returns (class output,
        # distillation output) in training mode and only the first is
        # trained here - confirm against the model definition.
        out = model(data)[0]
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), total_num,
                100. * (batch_idx + 1) / len(train_loader), print_loss))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate the model, save it when accuracy improves, and return accuracy.

    Reads the module-level names ``criterion`` (loss function), ``file_dir``
    (output directory) and ``Best_ACC`` (best accuracy so far); all must be
    defined by the surrounding script. ``Best_ACC`` is updated in place.

    Args:
        model: Network to evaluate; switched to eval mode here.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` and ``len()``.

    Returns:
        Fraction of correctly classified samples over the whole dataset.
    """
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    # The decorator already disables gradients; no inner no_grad is needed.
    for data, target in test_loader:
        # This script never imports Variable, so the tensors are moved
        # directly instead of being wrapped first.
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        _, pred = torch.max(out, 1)
        # Accumulate a plain int so the count is valid even with no batches.
        correct += torch.sum(pred == target).item()
        test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    if acc > Best_ACC:
        # The whole model object is saved, not just its state dict.
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return acc

全局参数

if __name__ == '__main__':
    # Create the directory that will hold the saved model.
    file_dir = 'StudentModel'
    if os.path.exists(file_dir):
        print('true')
    # exist_ok=True covers both the existing and the missing case, so there
    # is no separate branch that could fail if the directory appears in between.
    os.makedirs(file_dir, exist_ok=True)
    # Global settings.
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)

蒸馏学生网络

train_kd.py代码

# 导入必要的库
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.loss import LabelSmoothingCrossEntropy
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os
from losses import DistillationLoss
# 设置随机因子
def seed_everything(seed=42):
    """Seed the random number generators used by the script.

    Args:
        seed: Integer applied to the ``PYTHONHASHSEED`` environment variable,
            Python's ``random`` module and the PyTorch CPU/CUDA generators.
    """
    # Imported locally because the script's import header does not include it.
    import random

    # NOTE: changing PYTHONHASHSEED at runtime only affects child processes;
    # the running interpreter has already fixed its hash seed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels between runs.
    torch.backends.cudnn.benchmark = False
# 定义训练过程
def train(s_net, t_net, device, criterionKD, train_loader, optimizer, epoch):
    """Run one distillation training epoch of the student and print progress.

    Args:
        s_net: Student network to optimise; switched to training mode here.
        t_net: Teacher network. It is not used inside this function; the
            teacher is expected to be consulted by ``criterionKD``.
        device: Device the batches are moved to.
        criterionKD: Callable invoked as ``criterionKD(data, out_s, target)``
            that returns the loss to back-propagate.
        train_loader: Iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` and ``len()``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the printed messages.
    """
    s_net.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out_s = s_net(data)
        # The raw inputs are passed along so the criterion can run the
        # teacher on the same batch.
        loss = criterionKD(data, out_s, target)
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), total_num,
                100. * (batch_idx + 1) / len(train_loader), print_loss))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
# 验证过程
@torch.no_grad()
def val(model, device, criterionCls, test_loader):
    """Evaluate the model, save it when accuracy improves, and return accuracy.

    Reads the module-level names ``file_dir`` (output directory) and
    ``Best_ACC`` (best accuracy so far); both must be defined by the
    surrounding script. ``Best_ACC`` is updated in place.

    Args:
        model: Network to evaluate; switched to eval mode here.
        device: Device the batches are moved to.
        criterionCls: Callable invoked as ``criterionCls(output, target)``
            that returns the classification loss.
        test_loader: Iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` and ``len()``.

    Returns:
        Fraction of correctly classified samples over the whole dataset.
    """
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    # The decorator already disables gradients; no inner no_grad is needed.
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out_s = model(data)
        loss = criterionCls(out_s, target)
        _, pred = torch.max(out_s, 1)
        # Accumulate a plain int so the count is valid even with no batches.
        correct += torch.sum(pred == target).item()
        test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    if acc > Best_ACC:
        # The whole model object is saved, not just its state dict.
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return acc

全局参数

if __name__ == '__main__':
    # Create the directory that will hold the saved model.
    file_dir = 'KDModel'
    if os.path.exists(file_dir):
        print('true')
    # exist_ok=True covers both the existing and the missing case, so there
    # is no separate branch that could fail if the directory appears in between.
    os.makedirs(file_dir, exist_ok=True)
    # Global settings.
    modellr = 1e-4
    BATCH_SIZE = 4
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)
    # Distillation settings.
    distillation_type = 'hard'
    distillation_alpha = 0.5
    distillation_tau = 1.0

结果比对

# 导入必要的库
import json

import numpy as np
import matplotlib.pyplot as plt

# Result files to compare.
teacher_file = 'result.json'
student_file = 'result_student.json'
student_kd_file = 'result_kd.json'


def read_json(file):
    """Load a JSON file, print its content and return it.

    Args:
        file: Path of the UTF-8 encoded JSON file to read.

    Returns:
        The decoded JSON data.
    """
    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
    print(json_data)
    return json_data


# Load the three result sets. read_json is a function and must be called
# with the path; it has no attributes named after the path variables.
teacher_data = read_json(teacher_file)
student_data = read_json(student_file)
student_kd_data = read_json(student_kd_file)

# Plot the curves on a shared x axis sized by the teacher's results.
# NOTE(review): the legend labels mention "IRG" although this script compares
# plain and distilled students - confirm the intended wording.
x = list(range(len(teacher_data)))
plt.plot(x, list(teacher_data.values()), label='teacher')
plt.plot(x, list(student_data.values()), label='student without IRG')
plt.plot(x, list(student_kd_data.values()), label='student with IRG')
plt.title('Test accuracy')
plt.legend()
plt.show()

总结

本文详细介绍了如何通过外部蒸馏算法对DeiT模型进行知识蒸馏。通过实验结果可以看出,采用硬蒸馏方式能够有效提升学生网络的性能。代码和数据集详见链接:链接

转载地址:http://sjdbz.baihongyu.com/

你可能感兴趣的文章
mysql where中如何判断不为空
查看>>
MySQL Workbench 使用手册:从入门到精通
查看>>
MySQL Workbench 数据库建模详解:从设计到实践
查看>>
MySQL Workbench 数据建模全解析:从基础到实践
查看>>
mysql workbench6.3.5_MySQL Workbench
查看>>
MySQL Workbench安装教程以及菜单汉化
查看>>
MySQL Xtrabackup 安装、备份、恢复
查看>>
mysql [Err] 1436 - Thread stack overrun: 129464 bytes used of a 286720 byte stack, and 160000 bytes
查看>>
MySQL _ MySQL常用操作
查看>>
MySQL – 导出数据成csv
查看>>
MySQL —— 在CentOS9下安装MySQL
查看>>
MySQL —— 视图
查看>>
mysql 不区分大小写
查看>>
mysql 两列互转
查看>>
MySQL 中开启二进制日志(Binlog)
查看>>
MySQL 中文问题
查看>>
MySQL 中日志的面试题总结
查看>>
mysql 中的all,5分钟了解MySQL5.7中union all用法的黑科技
查看>>
MySQL 中的外键检查设置:SET FOREIGN_KEY_CHECKS = 1
查看>>
Mysql 中的日期时间字符串查询
查看>>