Donut Model Fine-Tuning Tutorial

 

This post shows how to fine-tune the Donut model for document understanding and parsing with Hugging Face Transformers. Donut is an OCR-free document-understanding model that performs strongly on visual document classification and information-extraction tasks. The tutorial covers setting up the development environment, loading the SROIE dataset, preparing the data, fine-tuning the model, and evaluating it. Using features of the Hugging Face ecosystem, such as model versioning and experiment tracking, you can easily fine-tune Donut and apply it to real-world scenarios.

📚 Donut is an OCR-free document-understanding model that processes scanned documents directly and reaches state-of-the-art results on visual document classification and information extraction.

🛠️ The tutorial walks through fine-tuning Donut with Hugging Face Transformers, covering development-environment setup, loading the SROIE dataset, data preparation, fine-tuning, and evaluation.

📊 Features of the Hugging Face ecosystem, such as model versioning and experiment tracking, make fine-tuning and experiment management more efficient.

🔍 The SROIE dataset used in the tutorial contains 1000 scanned receipts together with their OCR text, providing a rich basis for fine-tuning.

📈 Readers learn how to use Donut for document parsing and how to evaluate the model's accuracy and performance.

In this blog, you will learn how to fine-tune Donut-base for document understanding/document parsing using Hugging Face Transformers. Donut is a new document-understanding model achieving state-of-the-art performance with an MIT license, which allows it to be used for commercial purposes, unlike other models such as LayoutLMv2/LayoutLMv3. We are going to use all of the great features from the Hugging Face ecosystem, like model versioning and experiment tracking.

We will use the SROIE dataset, a collection of 1000 scanned receipts including their OCR. More information about the dataset can be found in the repository.

You will learn how to:

    1. Setup Development Environment
    2. Load SROIE dataset
    3. Prepare dataset for Donut
    4. Fine-tune and evaluate Donut model

Before we can start, make sure you have a Hugging Face Account to save artifacts and experiments.

Quick intro: Document Understanding Transformer (Donut) by ClovaAI

Document Understanding Transformer (Donut) is a new Transformer model for OCR-free document understanding. It doesn't require an OCR engine to process scanned documents, yet achieves state-of-the-art performance on various visual document understanding tasks, such as visual document classification or information extraction (a.k.a. document parsing). Donut is a multimodal sequence-to-sequence model with a vision encoder (Swin Transformer) and a text decoder (BART). The encoder receives the image and computes an embedding, which is then passed to the decoder, which generates a sequence of tokens.
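
To make the architecture concrete, here is a minimal sketch (not from the original post) that loads the pre-trained checkpoint and inspects the encoder/decoder composition; the exact class names printed depend on your transformers version.

from transformers import VisionEncoderDecoderModel

# Donut is exposed through the generic VisionEncoderDecoderModel class
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

# The vision encoder is Swin-based and the text decoder is BART-style
print(type(model.encoder).__name__)  # e.g. DonutSwinModel
print(type(model.decoder).__name__)  # e.g. MBartForCausalLM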


Now we know how Donut works, so let's get started. 🚀

Note: This tutorial was created and run on a p3.2xlarge AWS EC2 instance with an NVIDIA V100 GPU.

1. Setup Development Environment

Our first step is to install the Hugging Face Libraries, including transformers and datasets. Running the following cell will install all the required packages.

Note: At the time of writing, Donut is not yet included in the PyPI version of Transformers, so we need to install it from the main branch. Donut will be added in version 4.22.0.

!pip install -q git+https://github.com/huggingface/transformers.git
# !pip install -q "transformers>=4.22.0" # comment in when version is released
!pip install -q datasets sentencepiece tensorboard

# install git-lfs for pushing model and logs to the hugging face hub
!sudo apt-get install git-lfs --yes

This example will use the Hugging Face Hub as a remote model versioning service. To be able to push our model to the Hub, you need to register on Hugging Face. If you already have an account, you can skip this step. After you have an account, we will use the notebook_login util from the huggingface_hub package to log into our account and store our token (access key) on disk.

from huggingface_hub import notebook_login

notebook_login()

2. Load SROIE dataset

We will use the SROIE dataset, a collection of 1000 scanned receipts including their OCR; more specifically, we will use the dataset from task 2 "Scanned Receipt OCR". The dataset available on Hugging Face (darentang/sroie) is not compatible with Donut. That's why we will use the original dataset together with the imagefolder feature of datasets to load our dataset. Learn more about loading image data here.

Note: The test data for task 2 is sadly not available, meaning we end up with only 624 images.

First, we will clone the repository, extract the dataset into a separate folder and remove the unnecessary files.

%%bash
# clone repository
git clone https://github.com/zzzDavid/ICDAR-2019-SROIE.git
# copy data
cp -r ICDAR-2019-SROIE/data ./
# clean up
rm -rf ICDAR-2019-SROIE
rm -rf data/box

Now we have two folders inside the data/ directory. One contains the images of the receipts and the other contains the OCR text. The next step is to create a metadata.jsonl file that contains the information about the images, including the OCR text. This is required by the imagefolder feature of datasets.

At the end, the metadata.jsonl should look similar to the example below.

{"file_name": "0001.png", "text": "This is a golden retriever playing with a ball"}{"file_name": "0002.png", "text": "A german shepherd"}

In our example, the "text" column will contain the OCR text of the image, which will later be used to create the Donut-specific format.

import os
import json
from pathlib import Path
import shutil

# define paths
base_path = Path("data")
metadata_path = base_path.joinpath("key")
image_path = base_path.joinpath("img")
# define metadata list
metadata_list = []

# parse metadata
for file_name in metadata_path.glob("*.json"):
  with open(file_name, "r") as json_file:
    # load json file
    data = json.load(json_file)
    # create "text" column with json string
    text = json.dumps(data)
    # add to metadata list if image exists
    if image_path.joinpath(f"{file_name.stem}.jpg").is_file():
      metadata_list.append({"text":text,"file_name":f"{file_name.stem}.jpg"})
      # delete json file

# write jsonline file
with open(image_path.joinpath('metadata.jsonl'), 'w') as outfile:
    for entry in metadata_list:
        json.dump(entry, outfile)
        outfile.write('\n')

# remove old meta data
shutil.rmtree(metadata_path)

Good Job! Now we can load the dataset using the imagefolder feature of datasets.

import os
import json
from pathlib import Path
import shutil
from datasets import load_dataset

# define paths
base_path = Path("data")
metadata_path = base_path.joinpath("key")
image_path = base_path.joinpath("img")

# Load dataset
dataset = load_dataset("imagefolder", data_dir=image_path, split="train")

print(f"Dataset has {len(dataset)} images")
print(f"Dataset features are: {dataset.features.keys()}")

Now, let's take a closer look at our dataset.

import random

# pick a random sample index (len(dataset)-1 keeps the index in range)
random_sample = random.randint(0, len(dataset) - 1)

print(f"Random sample is {random_sample}")
print(f"OCR text is {dataset[random_sample]['text']}")
dataset[random_sample]['image'].resize((250,400))
#     OCR text is {"company": "LIM SENG THO HARDWARE TRADING", "date": "29/12/2017", "address": "NO 7, SIMPANG OFF BATU VILLAGE, JALAN IPOH BATU 5, 51200 KUALA LUMPUR MALAYSIA", "total": "6.00"}

3. Prepare dataset for Donut

As we learned in the introduction, Donut is a sequence-to-sequence model with a vision encoder and text decoder. When fine-tuning the model, we want it to generate the "text" based on the image we pass it. Similar to NLP tasks, we have to tokenize and preprocess the text. Before we can tokenize the text, we need to transform the JSON string into a Donut-compatible document.

current JSON string

{  "company": "ADVANCO COMPANY",  "date": "17/01/2018",  "address": "NO 1&3, JALAN WANGSA DELIMA 12, WANGSA LINK, WANGSA MAJU, 53300 KUALA LUMPUR",  "total": "7.00"}

Donut document

<s><s_company>ADVANCO COMPANY</s_company><s_date>17/01/2018</s_date><s_address>NO 1&3, JALAN WANGSA DELIMA 12, WANGSA LINK, WANGSA MAJU, 53300 KUALA LUMPUR</s_address><s_total>7.00</s_total></s>

To easily create those documents, the ClovaAI team created a json2token method, which we extract and then apply.

new_special_tokens = [] # new tokens which will be added to the tokenizer
task_start_token = "<s>"  # start of task token
eos_token = "</s>" # eos token of tokenizer

def json2token(obj, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
    """
    Convert an ordered JSON object into a token sequence
    """
    if type(obj) == dict:
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]
        else:
            output = ""
            if sort_json_key:
                keys = sorted(obj.keys(), reverse=True)
            else:
                keys = obj.keys()
            for k in keys:
                if update_special_tokens_for_json_key:
                    new_special_tokens.append(fr"<s_{k}>") if fr"<s_{k}>" not in new_special_tokens else None
                    new_special_tokens.append(fr"</s_{k}>") if fr"</s_{k}>" not in new_special_tokens else None
                output += (
                    fr"<s_{k}>"
                    + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                    + fr"</s_{k}>"
                )
            return output
    elif type(obj) == list:
        return r"<sep/>".join(
            [json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
        )
    else:
        # excluded special tokens for now
        obj = str(obj)
        if f"<{obj}/>" in new_special_tokens:
            obj = f"<{obj}/>"  # for categorical special tokens
        return obj


def preprocess_documents_for_donut(sample):
    # create Donut-style input
    text = json.loads(sample["text"])
    d_doc = task_start_token + json2token(text) + eos_token
    # convert all images to RGB
    image = sample["image"].convert('RGB')
    return {"image": image, "text": d_doc}

proc_dataset = dataset.map(preprocess_documents_for_donut)

print(f"Sample: {proc_dataset[45]['text']}")
print(f"New special tokens: {new_special_tokens + [task_start_token] + [eos_token]}")
#    Sample: <s><s_total>$6.90</s_total><s_date>27 MAR 2018</s_date><s_company>UNIHAKKA INTERNATIONAL SDN BHD</s_company><s_address>12, JALAN TAMPOI 7/4,KAWASAN PARINDUSTRIAN TAMPOI,81200 JOHOR BAHRU,JOHOR</s_address></s>
#    New special tokens: ['<s_total>', '</s_total>', '<s_date>', '</s_date>', '<s_company>', '</s_company>', '<s_address>', '</s_address>', '<s>', '</s>']
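
As a quick sanity check (not part of the original post), the DonutProcessor introduced below ships a token2json helper that reverses this mapping, so a generated document can be turned back into a dictionary; a minimal sketch, assuming the parsing behavior of recent transformers versions:

from transformers import DonutProcessor

# hypothetical round-trip check: convert a Donut-style document back into a dict
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
doc = "<s_company>ADVANCO COMPANY</s_company><s_total>7.00</s_total>"
print(processor.token2json(doc))
# expected (roughly): {'company': 'ADVANCO COMPANY', 'total': '7.00'}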

The next step is to tokenize our text and encode the images into tensors. Therefore, we need to load the DonutProcessor, add our new special tokens, and reduce the image size used during processing from [1920, 2560] to [720, 960] to use less memory and train faster.

from transformers import DonutProcessor

# Load processor
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

# add new special tokens to tokenizer
processor.tokenizer.add_special_tokens({"additional_special_tokens": new_special_tokens + [task_start_token] + [eos_token]})

# we update some settings which differ from pretraining; namely the size of the images + no rotation required
# resizing the image from [1920, 2560] to [720, 960]
processor.feature_extractor.size = [720,960] # should be (width, height)
processor.feature_extractor.do_align_long_axis = False

Now, we can prepare our dataset, which we will use for the training later.

def transform_and_tokenize(sample, processor=processor, split="train", max_length=512, ignore_id=-100):
    # create tensor from image
    try:
        pixel_values = processor(
            sample["image"], random_padding=split == "train", return_tensors="pt"
        ).pixel_values.squeeze()
    except Exception as e:
        print(sample)
        print(f"Error: {e}")
        return {}

    # tokenize document
    input_ids = processor.tokenizer(
        sample["text"],
        add_special_tokens=False,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)

    labels = input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = ignore_id  # model doesn't need to predict pad token
    return {"pixel_values": pixel_values, "labels": labels, "target_sequence": sample["text"]}

# need at least 32-64GB of RAM to run this
processed_dataset = proc_dataset.map(transform_and_tokenize, remove_columns=["image","text"])
# from datasets import load_from_disk
# from transformers import DonutProcessor

## COMMENT IN in case you want to save the processed dataset to disk in case of error later
# processed_dataset.save_to_disk("processed_dataset")
# processor.save_pretrained("processor")

## COMMENT IN in case you want to load the processed dataset from disk in case of error later
# processed_dataset = load_from_disk("processed_dataset")
# processor = DonutProcessor.from_pretrained("processor")

The last step is to split the dataset into train and validation sets.

processed_dataset = processed_dataset.train_test_split(test_size=0.1)
print(processed_dataset)

4. Fine-tune and evaluate Donut model

After we have processed our dataset, we can start training our model. Therefore, we first need to load the naver-clova-ix/donut-base model with the VisionEncoderDecoderModel class. The donut-base checkpoint includes only the pre-trained weights and was introduced in the paper OCR-free Document Understanding Transformer by Geewook Kim et al. and first released in this repository.

In addition to loading our model, we resize the embedding layer to match the newly added tokens and adjust the image_size of our encoder to match our dataset. We also set the pad and decoder start tokens needed for inference later.

import torch
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig

# Load model from huggingface.co
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

# Resize embedding layer to match vocabulary size
new_emb = model.decoder.resize_token_embeddings(len(processor.tokenizer))
print(f"New embedding size: {new_emb}")

# Adjust our image size and output sequence lengths
model.config.encoder.image_size = processor.feature_extractor.size[::-1] # (height, width)
model.config.decoder.max_length = len(max(processed_dataset["train"]["labels"], key=len))

# Add task token for decoder to start
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s>'])[0]

# is done by Trainer
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)

Before we can start training, we need to define the hyperparameters (Seq2SeqTrainingArguments) we want to use. We are leveraging the Hugging Face Hub integration of the Seq2SeqTrainer to automatically push our checkpoints, logs, and metrics into a repository during training.

from huggingface_hub import HfFolder
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# hyperparameters used for multiple args
hf_repository_id = "donut-base-sroie"

# Arguments for training
training_args = Seq2SeqTrainingArguments(
    output_dir=hf_repository_id,
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    weight_decay=0.01,
    fp16=True,
    logging_steps=100,
    save_total_limit=2,
    evaluation_strategy="no",
    save_strategy="epoch",
    predict_with_generate=True,
    # push to hub parameters
    report_to="tensorboard",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=hf_repository_id,
    hub_token=HfFolder.get_token(),
)

# Create Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset["train"],
)

We can start our training by using the train method of the Seq2SeqTrainer.

# Start training
trainer.train()

After our training is done we also want to save our processor to the Hugging Face Hub and create a model card.

# Save processor and create model card
processor.save_pretrained(hf_repository_id)
trainer.create_model_card()
trainer.push_to_hub()

We successfully trained our model. Now let's test it and then evaluate its accuracy.

import re
import transformers
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import random
import numpy as np

# hide logs
transformers.logging.disable_default_handler()

# Load our model from Hugging Face
processor = DonutProcessor.from_pretrained("philschmid/donut-base-sroie")
model = VisionEncoderDecoderModel.from_pretrained("philschmid/donut-base-sroie")

# Move model to GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Load random document image from the test set
test_sample = processed_dataset["test"][random.randint(1, 50)]

def run_prediction(sample, model=model, processor=processor):
    # prepare inputs
    pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    task_prompt = "<s>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # run inference
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # process output
    prediction = processor.batch_decode(outputs.sequences)[0]
    prediction = processor.token2json(prediction)

    # load reference target
    target = processor.token2json(sample["target_sequence"])
    return prediction, target

prediction, target = run_prediction(test_sample)
print(f"Reference:\n {target}")
print(f"Prediction:\n {prediction}")
processor.feature_extractor.to_pil_image(np.array(test_sample["pixel_values"])).resize((350,600))

Result

Reference:
{'total': '9.30', 'date': '26/11/2017', 'company': 'SANYU STATIONERY SHOP', 'address': 'NO. 31G&33G, JALAN SETIA INDAH X ,U13/X 40170 SETIA ALAM'}

Prediction:
{'total': '9.30', 'date': '26/11/2017', 'company': 'SANYU STATIONERY SHOP', 'address': 'NO. 31G&33G, JALAN SETIA INDAH X,U13/X 40170 SETIA ALAM'}

Nice 😍🔥 Our fine-tuned model parsed the document correctly and extracted the right values. Our next step is to evaluate our model on the test set. Since the model itself is a seq2seq model, evaluation is not that straightforward.

To keep things simple, we will use accuracy as the metric and compare the predicted value for each key in the dictionary to see if they are equal. This evaluation technique is simple and biased, since only exact matches count as correct; e.g., if the model misses a single whitespace, as in the example above, the value is not counted as correct.

from tqdm import tqdm

# define counters for samples
true_counter = 0
total_counter = 0

# iterate over dataset
for sample in tqdm(processed_dataset["test"]):
  prediction, target = run_prediction(sample)
  for s in zip(prediction.values(), target.values()):
    if s[0] == s[1]:
      true_counter += 1
    total_counter += 1

print(f"Accuracy: {(true_counter/total_counter)*100}%")
# Accuracy: 75.0%

Our model achieves an accuracy of 75% on the test set.

Note: The evaluation we did was very simple and only counted exact string matches as correct for each key of the dictionary, which is a strong bias in the evaluation. Given that, an accuracy of 75% is pretty good.

Our first inference test is an excellent example of why this metric is biased. There, the model predicted the address NO. 31G&33G, JALAN SETIA INDAH X,U13/X 40170 SETIA ALAM while the ground truth was NO. 31G&33G, JALAN SETIA INDAH X ,U13/X 40170 SETIA ALAM; the only difference is the whitespace between X and ,U13/X. In our evaluation loop, this was not counted as a correct value.
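
A minimal sketch (not part of the original evaluation) of how this particular source of bias could be reduced: collapsing whitespace before comparing would count the address example above as a match. The normalized_match helper below is hypothetical, not from the post.

import re

def normalized_match(prediction: str, reference: str) -> bool:
    # strip all whitespace so "X ,U13/X" and "X,U13/X" compare equal
    normalize = lambda s: re.sub(r"\s+", "", s)
    return normalize(prediction) == normalize(reference)

print(normalized_match(
    "NO. 31G&33G, JALAN SETIA INDAH X,U13/X 40170 SETIA ALAM",
    "NO. 31G&33G, JALAN SETIA INDAH X ,U13/X 40170 SETIA ALAM",
))  # True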


Thanks for reading! If you have any questions, feel free to contact me, through Github, or on the forum. You can also connect with me on Twitter or LinkedIn.
