掘金 人工智能 08月05日
基于ERNIE-4.5-0.3B医疗领域大模型一站式分布式训练部署
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文详细介绍了如何基于百度开源的ERNIE-4.5-0.3B模型,利用ERNIEKit框架进行医疗领域数据的SFT(监督式微调)和DPO(直接偏好优化)训练。文章从环境准备、模型下载、数据格式解析到具体的训练配置和执行,提供了全面的指导。同时,也展示了如何利用FastDeploy工具将训练完成的模型进行服务化部署,并通过API调用进行推理,以实现从模型训练到实际应用的高效衔接,为医疗大模型的落地应用提供了实用的技术路径。

🚀 **模型与训练框架介绍**:百度于2025年6月30日开源了文心大模型4.5系列,其中ERNIE-4.5-0.3B是适用于特定任务的轻量级稠密模型。本文以该模型为基础,结合ERNIEKit框架,旨在通过SFT和DPO方法,在医疗领域数据集上进行微调,以提升模型在医疗问答等任务上的表现。

🛠️ **环境搭建与模型获取**:文章详细列出了使用ERNIEKit进行模型训练所需的环境准备步骤,包括克隆项目仓库、安装依赖库(如FastDeploy)等。同时,提供了模型权重的下载方式,推荐使用aistudio_sdk以获得稳定高效的下载体验,确保了训练过程的顺利进行。

📊 **数据准备与格式解析**:为确保微调效果,文章重点介绍了SFT和DPO训练所需的数据格式。SFT数据格式包含`src`(用户提问)和`tgt`(模型回答),支持多轮对话。DPO数据格式则在此基础上增加了`response`字段,包含`chosen`(偏好回答)和`rejected`(非偏好回答)两类,并用`sort`字段区分,这为模型学习人类偏好提供了关键依据。

⚙️ **SFT与DPO训练配置与执行**:文章详细解析了SFT和DPO训练的YAML配置文件中的关键参数,如数据集路径、模型名称、微调方式(Full)、训练轮次、学习率、梯度累积步数以及多种性能优化设置(如`recompute`、`amp`等)。并提供了直接执行训练的命令,以及绘制训练损失曲线的方法,便于监控训练过程。

🚀 **模型部署与服务化**:训练完成后,文章介绍了如何利用FastDeploy工具将模型进行服务化部署。通过`erniekit server`命令启动推理服务,并提供了Python代码示例,展示了如何通过HTTP请求调用部署好的模型进行医疗问答的推理,实现了模型从训练到实际应用的关键环节。

基于ERNIE-4.5-0.3B医疗领域大模型一站式分布式训练部署

1.简介

2025年6月30日,百度正式开源文心大模型4.5系列,全面覆盖从具备47B激活参数的混合专家(MoE)模型,到轻量级的0.3B稠密模型,支持文本生成、多模态理解等多种任务场景。

ERNIE 4.5 Models Model Information
Model Category Model Input Modality Output Modality Context Window
Large Language Models (LLMs) ERNIE-4.5-300B-A47B-Base Text Text 128K
ERNIE-4.5-300B-A47B
ERNIE-4.5-21B-A3B-Base
ERNIE-4.5-21B-A3B
Vision-Language Models (VLMs) ERNIE-4.5-VL-424B-A47B-Base Text/Image/Video Text
ERNIE-4.5-VL-424B-A47B
ERNIE-4.5-VL-28B-A3B-Base
ERNIE-4.5-VL-28B-A3B
Dense Models ERNIE-4.5-0.3B-Base Text Text
ERNIE-4.5-0.3B

2.项目简介

使用健康数据集,基于ERNIE 0.3B模型进行微调

3.训练环境准备

# 1. 克隆 ERNIEKit 项目仓库!git clone https://gitee.com/hqu_ljc/ERNIE.git# 2. 安装 ERNIEKit 依赖%cd ERNIE!pip install -r requirements/gpu/requirements.txt!pip install -e . # 推荐使用可编辑模式安装# 3.安装FastDeploy!pip install https://paddle-whl.bj.bcebos.com/stable/fastdeploy-gpu-80_90/fastdeploy-gpu/fastdeploy_gpu-2.0.0-py3-none-any.whl

4.模型下载

模型权重已开源,可在Huggingface Hub、AiStudio、ModelScope等多个平台下载。在本项目中项目推荐使用aistudio_sdk下载模型,以获得稳定高效的体验。

!aistudio download --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle --local_dir baidu/ERNIE-4.5-0.3B-Paddle

5.数据准备

ERNIEKit中SFT训练支持erniekitalpaca格式训练数据,DPO支持erniekit训练数据,更多训练格式细节可以参考文档 ERNIEKit训练数据介绍

5.1 疗领域数据集

本项目实战营提供了erniekit格式的医疗领域问答数据集,路径如下:

/home/aistudio/data/data351566/├── train-sft.jsonl        # SFT训练数据集├── val-sft.jsonl        # SFT评估数据集├── train-dpo.jsonl       # DPO训练数据集└── val-dpo.jsonl        # DPO评估数据集

5.2 SFT数据格式

字段名是否必需类型说明
system可选string系统设置,设定模型角色和语境
src必需string[]用户提问内容
tgt必需string[]模型生成内容
label可选list长度与对话轮数一致,1代表本轮对话训练,0代表本轮对话不参与训练
{    "src": ["高甘油三酯血症的就诊科室是什么?"],    "tgt": ["内科;内分泌科"],}

5.3 DPO数据格式

字段名是否必需类型说明
system可选string系统设置,设定模型角色和语境
src必需string[]用户提问内容
tgt必需string[]模型生成内容,比src少一轮
response必需list包含chosen(偏好回答)和rejected(非偏好回答)对话,需要包含奇数轮
sort必需list区分 chosen/rejected (0=rejected, 1=chosen)

Notes:

{    "src": ["骨纤维异常增殖症的就诊科室是什么?"],     "tgt": [],     "response": [      ["外科;骨外科"],       ["骨纤维异常增生症是一种常见的骨骼疾病。"]    ],   "sort": [1, 0]}

6.SFT训练

SFT,全称为 Supervised Fine-Tuning(监督式微调),是大语言模型(LLM)训练流程中的关键步骤之一,通常用于指令微调(Instruction Tuning)阶段,使模型学会根据用户输入生成符合预期的响应。它是对预训练模型在特定任务或风格下进行有标签数据的监督学习微调。

6.1 配置

### datatrain_dataset_type: "erniekit"eval_dataset_type: "erniekit"train_dataset_path: "/home/aistudio/data/data351566/train-sft.jsonl"train_dataset_prob: "1.0"eval_dataset_path: "/home/aistudio/data/data351566/val-sft.jsonl"eval_dataset_prob: "1.0"max_seq_len: 8192num_samples_each_epoch: 6000000### modelmodel_name_or_path: baidu/ERNIE-4.5-0.3B-Paddlefine_tuning: Fullfuse_rope: Trueuse_sparse_head_and_loss_fn: True### finetuning# basestage: SFTseed: 23do_train: Truedo_eval: Truedistributed_dataloader: Falsedataloader_num_workers: 1batch_size: 2num_train_epochs: 1max_steps: 100max_evaluate_steps: 10000eval_steps: 100evaluation_strategy: stepssave_steps: 10000000save_total_limit: 5save_strategy: stepslogging_steps: 1release_grads: Truegradient_accumulation_steps: 16logging_dir: ./sft_vdl_logoutput_dir: ./output_sftdisable_tqdm: True# trainwarmup_steps: 20learning_rate: 1.0e-5lr_scheduler_type: cosinemin_lr: 1.0e-6layerwise_lr_decay_bound: 1.0# optimizerweight_decay: 0.1adam_epsilon: 1.0e-8adam_beta1: 0.9adam_beta2: 0.95offload_optim: True# performancetensor_parallel_degree: 1pipeline_parallel_degree: 1sharding_parallel_degree: 1sharding: stage1sequence_parallel: Truepipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recvrecompute: Falserecompute_use_reentrant: Truecompute_type: bf16fp16_opt_level: O2disable_ckpt_quant: Trueamp_master_grad: Trueamp_custom_white_list:  - lookup_table  - lookup_table_v2  - flash_attn  - matmul  - matmul_v2  - fused_gemm_epilogueamp_custom_black_list:  - reduce_sum  - softmax_with_cross_entropy  - c_softmax_with_cross_entropy  - elementwise_div  - sin  - cosunified_checkpoint: Trueunified_checkpoint_config: async_save

6.2 训练

!erniekit train /home/aistudio/work/run_sft_8k.yaml

6.3 Loss

## SFT训练损失曲线绘制import osfrom visualdl import LogReaderimport matplotlib.pyplot as pltoutput = "/home/aistudio/ERNIE/sft_vdl_log"log_files = [f for f in os.listdir(output) if f.endswith(".log")]if len(log_files) > 0:    reader = LogReader(file_path=os.path.join(output,log_files[-1]))    data = reader.get_data('scalar', 'train/loss')    train_loss = []    for i in range(len(data)):        train_loss.append(data[i].value)    steps = range(1, len(train_loss) + 1)    plt.figure(figsize=(10, 6))    plt.plot(steps, train_loss, 'b', label='Training loss')    plt.title('Training Loss')    plt.xlabel('Steps')    plt.ylabel('Loss')    plt.legend()    plt.grid(True)    plt.show()else:    print("No Log files! ")

6.4 训练参数解析

下面介绍了常用参数配置解析,更具体的训练参数配置可以参考ERNIEKit训练参数解析

参数配置值含义
train_dataset_type"erniekit"指定训练集格式为 ERNIEKit 的标准格式。
eval_dataset_type"erniekit"指定验证集格式为 ERNIEKit 的标准格式。
train_dataset_path"/home/aistudio/data/medical/sft1.jsonl"训练集文件路径。
eval_dataset_path"/home/aistudio/data/medical/sft1.jsonl"验证集文件的路径。
max_seq_len8192最大训练token数,这里使用packing数据流策略。
model_name_or_path"baidu/ERNIE-4.5-0.3B-Paddle"基础模型的文件路径。
fine_tuningFull核心参数:指定微调方式为 Full,即全量微调。
stageSFT指定训练阶段为监督微调(Supervised Fine-Tuning)。
do_train / do_evalTrue开启训练和评估流程。
per_device_train_batch_size1每张 GPU 上的单次训练样本数(微批次)。
gradient_accumulation_steps8梯度累积步数。有效批大小 = 1 × 8 = 8,用于模拟大批次训练。
num_train_epochs1总的训练轮次,表示整个数据集将被完整训练 1 次。
learning_rate1.0e-5学习率。全量微调时比 LoRA 要小,确保稳定收敛。

7.DPO训练

DPO,全称为 Direct Preference Optimization(直接偏好优化),是一种无需强化学习(如 PPO)的、用来训练大语言模型更符合人类偏好的新型训练方法,作为 RLHF(Reinforcement Learning with Human Feedback)中的 PPO 替代方案,训练更稳定,结构更简单。

7.1开始训练

!erniekit train /home/aistudio/work/run_dpo_8k.yaml

7.2 Train Loss

8.模型部署

模型训练完成后,真正的挑战是如何高效部署并投入使用。FastDeploy 是飞桨体系下专为大模型打造的推理部署工具,提供轻量、灵活的多端部署能力,覆盖从边缘设备到云端的多种场景。它支持主流接口标准(如 OpenAI API),让模型快速接入实际业务,打通训练到应用的关键一环。

8.1 服务化推理

使用erniekit快速推理模型

import subprocessimport threadingimport sysdef run_command(cmd):    # 启动子进程,捕获stdout和stderr    process = subprocess.Popen(cmd, shell=True,                               stdout=subprocess.PIPE,                               stderr=subprocess.STDOUT,                              universal_newlines=True)        # 实时打印输出    for line in process.stdout:        sys.stdout.write(line)        sys.stdout.flush()        # 等待进程结束    process.wait()cmd = "erniekit server /home/aistudio/work/run_chat.yaml"t = threading.Thread(target=run_command, args=(cmd,))t.start()
# 等上面启动服务后运行下面的对话测试代码import requestsimport jsonurl = "http://0.0.0.0:8188/v1/chat/completions"headers = {"Content-Type": "application/json"}data = {    "messages": [        {            "role": "user",            "content": "皮肤每天都过敏,胳膊、大腿起红包,吃什么药?"        }    ]}try:    response = requests.post(url, headers=headers, json=data)    response.raise_for_status()  # 检查请求是否成功        # 获取响应数据    result = response.json()    print("=== 完整响应 ===")    print(json.dumps(result, indent=2, ensure_ascii=False))        # 提取并格式化对话内容    print("\n=== 对话内容 ===")    # 打印用户输入的问题    user_msg = data["messages"][0]    print(f"[User]: {user_msg['content']}")    if "choices" in result and len(result["choices"]) > 0:        for choice in result["choices"]:            if "message" in choice:                msg = choice["message"]                print(f"[{msg['role']}]: {msg['content']}\n")    except requests.exceptions.RequestException as e:    print(f"请求出错: {e}")except ValueError as e:    print(f"JSON解析出错: {e}")
=== 完整响应 ==={  "id": "chatcmpl-43f70e43-f109-424e-b74c-6dfdedc3b3e6",  "object": "chat.completion",  "created": 1754403559,  "model": "default",  "choices": [    {      "index": 0,      "message": {        "role": "assistant",        "content": "这种情况可以考虑口服氯雷他定治疗,或外用应用氢化胱氨酸针刺激红肿的部位,现代医学体外照射雷管,吸出肛门周围毛细血管,但不能排除雌激素射管是否引来勃起没办法,建议巧妙使用神经性皮炎万通iveru中药内服调理。对于生活护理:有的病人发作期在网络上去获取关于自我的知识和自我调剂,了解自身情绪,能够调解紧张的心情,放松精神。病人生活中习惯规律,饮食有度,适当的多运动,也有润出来的创伤。",        "reasoning_content": null,        "tool_calls": null      },      "finish_reason": "stop"    }  ],  "usage": {    "prompt_tokens": 20,    "total_tokens": 130,    "completion_tokens": 110,    "prompt_tokens_details": {      "cached_tokens": 0    }  }}=== 对话内容 ===[User]: 皮肤每天都过敏,胳膊、大腿起红包,吃什么药?[assistant]: 这种情况可以考虑口服氯雷他定治疗,或外用应用氢化胱氨酸针刺激红肿的部位,现代医学体外照射雷管,吸出肛门周围毛细血管,但不能排除雌激素射管是否引来勃起没办法,建议巧妙使用神经性皮炎万通iveru中药内服调理。对于生活护理:有的病人发作期在网络上去获取关于自我的知识和自我调剂,了解自身情绪,能够调解紧张的心情,放松精神。病人生活中习惯规律,饮食有度,适当的多运动,也有润出来的创伤。

8.2参数介绍

参数配置值含义
model_name_or_path"/home/aistudio/ERNIE/output_dpo"模型文件路径。
tensor_parallel_degree1我们只有单卡设为1。
max_model_len8192server端参数,模型推理支持最长长度
port8188server端参数,server端口号
max_new_tokens1024client端参数,最大生成token数。
top_p0.7client端参数,topP采样策略参数。
temperature0.95client端参数,温度参数。

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

ERNIE-4.5 医疗大模型 ERNIEKit SFT DPO 模型部署 FastDeploy
相关文章