Gemma Fine-Tuning Tutorial: Using ChatML and Hugging Face TRL

This article explains how to fine-tune Google's open Gemma large language model using Hugging Face TRL and the ChatML format. It walks through setting up the development environment, preparing the dataset, fine-tuning the model, and testing and evaluation. QLoRA is used to reduce memory usage, with a working example on a consumer-grade GPU. The content covers installing dependencies, loading the Gemma model, data processing, training-parameter configuration, and inference testing, and is aimed at developers who want to optimize their LLM deployments.

📊 Using the Gemma 7B model as an example, the article walks through the complete fine-tuning workflow with Hugging Face TRL and the ChatML format, covering development-environment setup, dataset preparation, model loading, parameter configuration, and training and evaluation, giving developers a reproducible fine-tuning recipe.

🔧 For the development environment, the tutorial covers installing the core Hugging Face libraries, including transformers, datasets, and accelerate, and explains how Flash Attention reorders the attention computation to significantly speed up training (up to 3x) and improve memory efficiency, while noting the GPU-architecture requirements.

📈 The dataset-preparation section shows how to fine-tune on the Databricks Dolly dataset, which is already formatted as ChatML-style conversations (with system, user, and assistant roles). The latest trl release can load this format directly without additional preprocessing, simplifying the data pipeline.

🚀 The fine-tuning section details the SFTTrainer configuration and how to combine it with QLoRA from PEFT (Parameter-Efficient Fine-Tuning), training only adapter layers to sharply reduce memory usage while preserving performance, and explains the reasoning behind key settings such as lora_alpha and dropout.

🔍 The testing and evaluation section loads the fine-tuned model, runs text-generation inference with a pipeline, and uses a simple loop with accuracy as the metric, demonstrating the model on varied tasks including question answering, code generation, and math.

Last week, Google released Gemma, a new family of state-of-the-art open LLMs. Gemma comes in two sizes: a 7B-parameter version for efficient deployment and development on consumer-size GPUs and TPUs, and a 2B version for CPU and on-device applications. Both come in base and instruction-tuned variants.

After the first week it seemed that Gemma is not very friendly to fine-tune using the ChatML format, which is adopted and used by the open source community, e.g. OpenHermes or Dolphin. I created this blog post to show you how to fine-tune Gemma using ChatML and Hugging Face TRL.

This blog post is derived from my How to Fine-Tune LLMs in 2024 with Hugging Face blog, tailored to fine-tuning Gemma 7B. We will use Hugging Face TRL, Transformers & datasets.

    1. Setup development environment
    2. Create and prepare the dataset
    3. Fine-tune LLM using trl and the SFTTrainer
    4. Test and evaluate the LLM

Note: This blog was created to run on consumer size GPUs (24GB), e.g. NVIDIA A10G or RTX 4090/3090, but can be easily adapted to run on bigger GPUs.

1. Setup development environment

Our first step is to install PyTorch and the Hugging Face libraries, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets, which makes it easier to fine-tune, RLHF-train, and align open LLMs.

# Install PyTorch & other libraries
!pip install "torch==2.1.2" tensorboard

# Install Hugging Face libraries
!pip install --upgrade \
  "transformers==4.38.2" \
  "datasets==2.16.1" \
  "accelerate==0.26.1" \
  "evaluate==0.4.1" \
  "bitsandbytes==0.42.0" \
  "trl==0.7.11" \
  "peft==0.8.2"

If you are using a GPU with the Ampere architecture (e.g. NVIDIA A10G or RTX 4090/3090) or newer, you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. TL;DR: it accelerates training by up to 3x. Learn more at FlashAttention.

Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of MAX_JOBS. On the g5.2xlarge we used 4.

import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'

# install flash-attn
!pip install ninja packaging
!MAX_JOBS=4 pip install flash-attn --no-build-isolation --upgrade

Installing flash attention from source can take quite a bit of time (10-45 minutes).

We also need to log in to our Hugging Face account to be able to access Gemma. To use Gemma you first need to agree to the terms of use. You can do this by visiting the Gemma model page and following the gating mechanism.

from huggingface_hub import login

login(
  token="", # ADD YOUR TOKEN HERE
  add_to_git_credential=True
)

2. Create and prepare the dataset

We are not going to focus on creating a dataset in this blog post. If you want to learn more about creating a dataset, I recommend reading the How to Fine-Tune LLMs in 2024 with Hugging Face blog post. We are going to use the Databricks Dolly dataset, already formatted as messages. This means we can use the conversational format to fine-tune our model.

{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

The latest release of trl supports the conversational dataset format. This means we don't need to do any additional formatting of the dataset. We can use the dataset as is.

from datasets import load_dataset

# Load Dolly Dataset.
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

print(dataset[3]["messages"])
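If you want to build a similar dataset yourself from the raw Dolly release, a minimal sketch of the conversion could look like this. The column names come from databricks/databricks-dolly-15k; the exact mapping used for philschmid/dolly-15k-oai-style (e.g. whether a system turn is added) may differ, so treat this as an illustration only.

from datasets import load_dataset

# Load the raw Dolly dataset (columns: instruction, context, response, category)
raw = load_dataset("databricks/databricks-dolly-15k", split="train")

def to_messages(sample):
    # Fold the optional context into the user turn; this mapping is illustrative
    user_content = sample["instruction"]
    if sample["context"]:
        user_content += f"\n\n{sample['context']}"
    return {"messages": [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": sample["response"]},
    ]}

converted = raw.map(to_messages, remove_columns=raw.column_names)
print(converted[0]["messages"])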

3. Fine-tune LLM using trl and the SFTTrainer

We will use the SFTTrainer from trl to fine-tune our model. The SFTTrainer makes it straightforward to supervised fine-tune open LLMs. The SFTTrainer is a subclass of the Trainer from the transformers library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:

    - Dataset formatting, including conversational and instruction format
    - Training on completions only, ignoring prompts
    - Packing datasets for more efficient training
    - PEFT (parameter-efficient fine-tuning) support including Q-LoRA
    - Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)

We will use the dataset formatting, packing, and PEFT features in our example. As the PEFT method we will use QLoRA, a technique that reduces the memory footprint of large language models during fine-tuning by using quantization, without sacrificing performance.

Note: Gemma comes with a big vocabulary of ~250,000 tokens. Normally, if you want to fine-tune LLMs on the ChatML format, you would need to add special tokens to the tokenizer and model and teach it to understand the different roles in a conversation. But Google included ~100 placeholder tokens in the vocabulary, which we can replace with special tokens like <|im_start|> and <|im_end|>. I created a tokenizer for the ChatML format, philschmid/gemma-tokenizer-chatml, which you can use to fine-tune Gemma with ChatML.

The chat template used during fine-tuning is not 100% compatible with the ChatML format, since google/gemma-7b requires inputs to always start with a <bos> token. This means our inputs will look like this:

<bos><|im_start|>system
You are Gemma.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<eos>

Note: We don't know why Gemma needs a <bos> token at the beginning of the input.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-7b"
tokenizer_id = "philschmid/gemma-tokenizer-chatml"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.padding_side = 'right' # to prevent warnings
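Optionally, you can sanity-check that the tokenizer's chat template renders conversations the way shown above. A minimal sketch follows; the exact placement of <bos>/<eos> in the output depends on how the template in philschmid/gemma-tokenizer-chatml is defined, so compare against the example above rather than treating this as authoritative.

# Render a sample conversation with the tokenizer's chat template (no tokenization)
messages = [
    {"role": "system", "content": "You are Gemma."},
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))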

The SFTTrainer  supports a native integration with peft, which makes it super easy to efficiently tune LLMs using, e.g. QLoRA. We only need to create our LoraConfig and provide it to the trainer.

from peft import LoraConfig

# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(
        lora_alpha=8,
        lora_dropout=0.05,
        r=6,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)

Before we can start our training we need to define the hyperparameters (TrainingArguments) we want to use.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="gemma-7b-dolly-chatml",     # directory to save and repository id
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=2,          # batch size per device during training
    gradient_accumulation_steps=2,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=False,                      # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
)

We now have every building block we need to create our SFTTrainer and start training our model.

from trl import SFTTrainer

max_seq_length = 1512 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    }
)

Start training our model by calling the train() method on our Trainer instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapter weights and not the full model.

# start training, the model will be automatically saved to the output directory
# (and to the Hugging Face Hub if push_to_hub=True)
trainer.train()

# save model
trainer.save_model()

The training with Flash Attention for 3 epochs on a dataset of 15k samples took 4:14:36 on a g5.2xlarge. The instance costs $1.21/h, which brings us to a total cost of only ~$5.30.

Optional: Merge the LoRA adapter into the original model

When using QLoRA, we only train adapters and not the full model. This means that when saving the model during training we only save the adapter weights, not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the merge_and_unload method and then save the model with the save_pretrained method.

Check out the How to Fine-Tune LLMs in 2024 with Hugging Face blog post for how to do it.
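For reference, a minimal sketch of what that merge step could look like, assuming the adapter was saved to the output directory used above; the merged-model path is illustrative and not from the original post.

import torch
from peft import AutoPeftModelForCausalLM

# Load the base model together with the trained adapter weights (output_dir from training)
model = AutoPeftModelForCausalLM.from_pretrained(
    "gemma-7b-dolly-chatml",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Merge the LoRA weights into the base model and drop the adapter layers
merged_model = model.merge_and_unload()

# Save the full merged model; the target path is just an example
merged_model.save_pretrained("gemma-7b-dolly-chatml-merged", safe_serialization=True, max_shard_size="2GB")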

4. Test Model and run Inference

After the training is done we want to evaluate and test our model. We will load different samples from the original dataset and evaluate the model on those samples, using a simple loop and accuracy as our metric.

Note: Evaluating Generative AI models is not a trivial task since 1 input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out Evaluate LLMs and RAG a practical example using Langchain and Hugging Face blog post.

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

We load the adapted model and the tokenizer into the pipeline to easily test it, and extract the token id of <|im_end|> to use in the generate method.

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline

peft_model_id = "gemma-7b-dolly-chatml"

# Load Model with PEFT adapter
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", torch_dtype=torch.float16)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# get token id for end of conversation
eos_token = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]

Let's test some prompt samples and see how the model performs.

prompts = [
    "What is the capital of Germany? Explain why thats the case and if it was different in the past?",
    "Write a Python function to calculate the factorial of a number.",
    "A rectangular garden has a length of 25 feet and a width of 15 feet. If you want to build a fence around the entire garden, how many feet of fencing will you need?",
    "What is the difference between a fruit and a vegetable? Give examples of each.",
]

def test_inference(prompt):
    prompt = pipe.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    outputs = pipe(prompt, max_new_tokens=1024, do_sample=True, temperature=0.7, top_k=50, top_p=0.95, eos_token_id=eos_token)
    return outputs[0]['generated_text'][len(prompt):].strip()

for prompt in prompts:
    print(f"    prompt:\n{prompt}")
    print(f"    response:\n{test_inference(prompt)}")
    print("-"*50)
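The post mentions evaluating with a simple loop and accuracy as the metric. Since free-form generations rarely match a reference exactly, the sketch below is an assumption on my part rather than code from the original post: it uses a lenient substring match over a small slice of the original dataset as a crude proxy, reusing the dataset and test_inference defined above.

# Take a small slice of the original dataset as a rough evaluation set (size is arbitrary)
eval_samples = dataset.shuffle(seed=42).select(range(50))

success = 0
for sample in eval_samples:
    # Use the last user turn as the prompt and the last assistant turn as the reference answer
    prompt = [m["content"] for m in sample["messages"] if m["role"] == "user"][-1]
    reference = [m["content"] for m in sample["messages"] if m["role"] == "assistant"][-1]
    prediction = test_inference(prompt)
    # Crude proxy for correctness: lenient substring match in either direction
    if reference.lower() in prediction.lower() or prediction.lower() in reference.lower():
        success += 1

print(f"Rough accuracy: {success / len(eval_samples):.2%}")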

Thanks for reading! If you have any questions, feel free to contact me on Twitter or LinkedIn.
