1. Model Distillation:
Model distillation is a technique for transferring the knowledge of a large, complex model (the teacher) into a smaller, more efficient model (the student). Its goal is to keep most of the teacher's performance while sharply reducing compute usage and inference time, so the model can be deployed on edge devices such as phones and IoT hardware. The example in this article distills bert-base-uncased (teacher) into distilbert-base-uncased (student). The core steps of model distillation are:
- Train the teacher model: train a high-performing but expensive model (e.g. BERT, ResNet) on a large dataset.
- Generate soft labels: run the teacher over the training data to obtain probability distributions (soft labels).
- Train the student model: the student learns from both the soft labels (via a KL-divergence loss) and the ground-truth labels (via a cross-entropy loss).
- Adjust the temperature: train at a high temperature, infer at T = 1. With T > 1 the probability distribution becomes smoother and secondary classes carry visible information; with T = 1 it is the standard softmax. Use a higher T during training and reset to T = 1 at inference, as illustrated in the sketch below.
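To make the effect of the temperature concrete, here is a minimal sketch using PyTorch (the same library as the rest of the example); the logit values are purely illustrative.

import torch

# Hypothetical logits for a 3-class example (illustrative values only)
logits = torch.tensor([4.0, 1.0, 0.5])

# T = 1: standard softmax, the top class takes almost all of the probability mass
print(torch.softmax(logits / 1.0, dim=0))  # roughly [0.93, 0.05, 0.03]

# T = 4: smoother distribution, the secondary classes become visible
print(torch.softmax(logits / 4.0, dim=0))  # roughly [0.53, 0.25, 0.22]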
2. Code Example
First, define a Dataset class:
# Imports used throughout this example
import torch
import optuna
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    DistilBertForSequenceClassification,
)


class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts            # raw text samples
        self.labels = labels          # label for each text
        self.tokenizer = tokenizer    # tokenizer used to encode the texts
        self.max_length = max_length  # maximum sequence length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        item = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)  # make sure the label is a tensor
        }
        return item
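As an optional sanity check, the sketch below shows what a single encoded item looks like. It assumes the tokenizer and the example texts/labels defined in the following steps are already in scope; sample_dataset is just a throwaway name used for illustration.

# Assumes `tokenizer`, `texts`, and `labels` from the following steps are defined
sample_dataset = TextClassificationDataset(texts, labels, tokenizer, max_length=64)
sample = sample_dataset[0]
print(sample['input_ids'].shape)       # torch.Size([64])
print(sample['attention_mask'].shape)  # torch.Size([64])
print(sample['labels'])                # e.g. tensor(1)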
3. Prepare the training and test data:

# Example data - sentiment analysis (0: negative, 1: positive)
texts = [
    "This movie was fantastic, the acting was superb!",
    "A complete waste of time and money.",
    "The plot was average, but the special effects were decent.",
    "Highly recommended, one of the best films of the year!",
    "Terrible directing and writing, very disappointing.",
    "A strong cast, but the story lacks depth.",
    "Gripping from start to finish, never a dull moment.",
    "Beautiful cinematography, but the plot is too predictable."
]
labels = [1, 0, 1, 1, 0, 0, 1, 0]

# Test data
test_texts = [
    "Not great, but not terrible either",
    "An absolute masterpiece, flawless"
]
test_labels = [0, 1]
4. Define the configuration class needed for distillation, and load the teacher and student models:

class Config:
    # Both models can be downloaded locally (Git LFS required):
    #   git clone https://hf-mirror.com/google-bert/bert-base-uncased
    #   git clone https://hf-mirror.com/distilbert/distilbert-base-uncased
    teacher_model_name = "local path to bert-base-uncased"
    student_model_name = "local path to distilbert-base-uncased"
    number_labels = 2
    batch_size = 2
    learning_rate = 5e-5  # learning rate
    num_epochs = 10
    max_length = 64
    temperature = 2  # temperature: controls how smooth the soft labels are
    alpha = 0.5      # weight of the soft-label (distillation) loss
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


config = Config()
tokenizer = BertTokenizer.from_pretrained(config.teacher_model_name)

train_dataset = TextClassificationDataset(texts, labels, tokenizer, config.max_length)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, config.max_length)
test_loader = DataLoader(test_dataset, batch_size=1)

teacher_model = BertForSequenceClassification.from_pretrained(
    config.teacher_model_name, num_labels=config.number_labels).to(config.device)
student_model = DistilBertForSequenceClassification.from_pretrained(
    config.student_model_name, num_labels=config.number_labels).to(config.device)

for param in teacher_model.parameters():
    param.requires_grad = False  # freeze the teacher model's parameters

optimizer = torch.optim.AdamW(student_model.parameters(), lr=config.learning_rate)

Here we load the two pretrained models, freeze all of the teacher model's parameters, and create an optimizer over the student's parameters only.
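As an optional check of why distillation matters for deployment, the two models' parameter counts can be compared right after loading; exact numbers depend on the checkpoints, but bert-base-uncased is typically around 110M parameters and distilbert-base-uncased around 66M.

# Optional: compare model sizes to see why the student is cheaper to run
teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_model.parameters())
print(f"Teacher parameters: {teacher_params / 1e6:.1f}M")
print(f"Student parameters: {student_params / 1e6:.1f}M")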
5. Define the loss function
def distill_loss(student_logits, teacher_logits, labels, temperature, alpha):
    """
    Compute the knowledge-distillation loss.
    :param student_logits: outputs of the student model
    :param teacher_logits: outputs of the teacher model
    :param labels: ground-truth labels
    :param temperature: temperature parameter
    :param alpha: weight of the soft-label (distillation) loss
    :return: loss value
    """
    soft_loss = torch.nn.KLDivLoss(reduction='batchmean')(
        torch.log_softmax(student_logits / temperature, dim=1),
        torch.softmax(teacher_logits / temperature, dim=1)
    ) * (temperature ** 2)
    hard_loss = torch.nn.CrossEntropyLoss()(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
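This implements loss = alpha * T^2 * KL(teacher_soft || student_soft) + (1 - alpha) * CE(student_logits, labels); the T^2 factor keeps the gradient of the soft loss on a scale comparable to the hard loss. As a quick sanity check, the function can be called with dummy logits (illustrative values only, not part of the training pipeline):

# Quick sanity check with dummy logits
dummy_student = torch.randn(4, 2)   # batch of 4, 2 classes
dummy_teacher = torch.randn(4, 2)
dummy_labels = torch.tensor([0, 1, 1, 0])
loss = distill_loss(dummy_student, dummy_teacher, dummy_labels, temperature=2, alpha=0.5)
print(loss)  # a scalar tensor; backward() is called on it during training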
6. Train (distill) the model and evaluate it:

def train(model, data_loader, optimizer):
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(config.device)
        attention_mask = batch['attention_mask'].to(config.device)
        labels = batch['labels'].to(config.device)

        optimizer.zero_grad()
        with torch.no_grad():  # the frozen teacher only provides soft targets
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits
        student_outputs = model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        loss = distill_loss(student_logits, teacher_logits, labels,
                            config.temperature, config.alpha)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)


def evaluate(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            _, predicted = torch.max(logits, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total
7. Run training and evaluation:

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
    # Train
    train_loss = train(student_model, train_loader, optimizer)
    print(f"Train Loss: {train_loss:.4f}")
    # Evaluate
    accuracy = evaluate(student_model, test_loader)
    print(f"Test Accuracy: {accuracy:.2f}")
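After training, the distilled student can run on its own, without the teacher. Below is a minimal inference sketch; it assumes the tokenizer, config, and trained student_model from the steps above, and the input sentence is just an illustrative example.

# Inference with the distilled student only (the teacher is no longer needed)
student_model.eval()
sample_text = "One of the best films I have seen this year."  # illustrative input
encoding = tokenizer(sample_text, truncation=True, padding='max_length',
                     max_length=config.max_length, return_tensors='pt').to(config.device)
with torch.no_grad():
    logits = student_model(encoding['input_ids'],
                           attention_mask=encoding['attention_mask']).logits
prediction = torch.argmax(logits, dim=1).item()
print("positive" if prediction == 1 else "negative")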
8. Use the Optuna framework to search for the best distillation hyperparameters:

def objective(trial):
    params = {
        'temperature': trial.suggest_float('temperature', 1.0, 15.0),
        'alpha': trial.suggest_float('alpha', 0.1, 0.9),
        'learning_rate': trial.suggest_float('learning_rate', 1e-6, 5e-5, log=True),
        'num_epochs': 5,
    }
    student_model = DistilBertForSequenceClassification.from_pretrained(
        config.student_model_name, num_labels=config.number_labels)
    student_model.to(config.device)
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=params['learning_rate'])

    best_accuracy = 0.0
    for epoch in range(params['num_epochs']):
        student_model.train()
        for batch in train_loader:
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)

            optimizer.zero_grad()
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits
            student_outputs = student_model(input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            loss = distill_loss(student_logits, teacher_logits, labels,
                                params['temperature'], params['alpha'])
            loss.backward()
            optimizer.step()

        accuracy = evaluate(student_model, test_loader)
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        if accuracy > best_accuracy:
            best_accuracy = accuracy
    return best_accuracy


# Create the Optuna study
study = optuna.create_study(
    direction='maximize',                  # maximize accuracy
    sampler=optuna.samplers.TPESampler(),  # TPE sampler
    pruner=optuna.pruners.MedianPruner()   # median pruner: stops unpromising trials early
)

# Run the optimization: at most 20 trials or 10 minutes
study.optimize(objective, n_trials=20, timeout=600)

# Report the best result
print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial
print(f"  Value (Accuracy): {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")


def best_params_train():
    best_params = study.best_params
    final_model = DistilBertForSequenceClassification.from_pretrained(
        config.student_model_name, num_labels=config.number_labels
    ).to(config.device)
    optimizer = torch.optim.AdamW(final_model.parameters(), lr=best_params['learning_rate'])

    for epoch in range(5):
        final_model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Final Training:{epoch + 1}"):
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)

            optimizer.zero_grad()
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits
            student_outputs = final_model(input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            loss = distill_loss(student_logits, teacher_logits, labels,
                                best_params['temperature'], best_params['alpha'])
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluate after each epoch
        accuracy = evaluate(final_model, test_loader)
        print(f"Epoch {epoch + 1} - Loss: {total_loss / len(train_loader):.4f}, Accuracy: {accuracy:.4f}")

    # Save the final model
    final_model.save_pretrained('optimized_distilled_distilbert')
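To actually run the final retraining and reuse the saved model later, something along the following lines works. Reloading the directory written by save_pretrained is standard transformers usage; saving the tokenizer alongside it is an optional extra step that is not in the original code.

# Retrain with the best hyperparameters found by Optuna and save the distilled model
best_params_train()

# Optional: save the tokenizer next to the model so the directory is self-contained
tokenizer.save_pretrained('optimized_distilled_distilbert')

# Later, reload the distilled model for inference
reloaded = DistilBertForSequenceClassification.from_pretrained(
    'optimized_distilled_distilbert').to(config.device)
reloaded.eval()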
