Huggingface Core Modules (2): datasets

datasets is one of the most important libraries in the Huggingface ecosystem. It lets us load datasets quickly and also process them, including splitting, caching, and downloading. This article covers datasets in detail: basic usage, the most common methods, and how to build a custom dataset.

1. Basic usage of datasets

1.1 Loading datasets from the Huggingface Hub

Every dataset hosted on the Huggingface Hub can be downloaded and used directly. The whole process goes through load_dataset (official documentation); its key arguments are sketched below.
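
A minimal sketch of the most frequently used arguments (parameter names follow the load_dataset signature in the datasets library; the values here are purely illustrative):

from datasets import load_dataset

# Only "path" is required; the other arguments are optional.
dataset = load_dataset(
    path="imdb",              # dataset name on the Hub, or a loader type such as "json"/"csv"
    # name="plain_text",      # dataset configuration (subset), if there is more than one
    # split="train",          # load a single split instead of the full DatasetDict
    # data_files=None,        # local or remote files when loading your own data
    # cache_dir=None,         # where downloaded/processed data is cached
)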

Load the imdb dataset from the Huggingface Hub:

>>> from datasets import load_dataset
>>> dataset = load_dataset("imdb")
Downloading metadata: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 2.17k/2.17k [00:00<00:00, 6.82MB/s]
Downloading readme: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 7.59k/7.59k [00:00<00:00, 11.0MB/s]
Downloading and preparing dataset imdb/plain_text to /Users/harry/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0...
Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 84.1M/84.1M [00:48<00:00, 1.73MB/s]
Dataset imdb downloaded and prepared to /Users/harry/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0. Subsequent calls will reuse this data.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 152.18it/s]
>>> dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})
>>> 

A single line of code downloads the imdb dataset from the Huggingface Hub. Printing the result shows a DatasetDict object containing three splits: train, test, and unsupervised. The following sections introduce the dataset's parameters and usage.
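
For example, each split can be indexed like an ordinary Dataset (a small illustrative snippet; nothing here is specific to imdb beyond its text/label columns):

# Access a single split and index into it
train_ds = dataset["train"]
print(train_ds[0]["text"])    # the first review
print(train_ds[0]["label"])   # its label (0 or 1)

# Alternatively, load only one split up front
train_only = load_dataset("imdb", split="train")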

1.2 Loading a local dataset

Here we use squad_it, an Italian question answering dataset, to show how to download and process a local dataset.

1. Manually download and extract the squad_it dataset

We first download the squad_it dataset manually to take a look at its format.

>>> wget https://github.com/crux82/squad-it/raw/master/SQuAD_it-train.json.gz
>>> wget https://github.com/crux82/squad-it/raw/master/SQuAD_it-test.json.gz

# extract
>>> gzip -dkv SQuAD_it-*.json.gz 
SQuAD_it-train.json.gz:    82.2% -- replaced with SQuAD_it-train.json
SQuAD_it-test.json.gz:     87.4% -- replaced with SQuAD_it-test.json

After downloading, let's look at the contents of the dataset:

{
    "data": [
        {
            "title": "Terremoto del Sichuan del 2008",
            "paragraphs": [
                {
                    "context": "Il terremoto del Sichuan del 2008 o il terremoto del Gran Sichuan, misurato a 8.0 Ms e 7.9 Mw, e si è verificato alle 02:28:01 PM China Standard Time all' epicentro (06:28:01 UTC) il 12 maggio nella provincia del Sichuan, ha ucciso 69.197 persone e lasciato 18.222 dispersi.",
                    "qas": [
                        {
                            "id": "56cdca7862d2951400fa6826",
                            "answers": [
                                {
                                    "text": "2008",
                                    "answer_start": 29
                                }
                            ],
                            "question": "In quale anno si è verificato il terremoto nel Sichuan?"
                        },
                        ...

The fields are organized as follows:

|-data
    |-title
    |-paragraphs
        |-context
        |-qas
            |-id
            |-answers
                |-text
                |-answer_start
            |-question
            |-id
            |-answers
                |-text
                |-answer_start
            |-question

Next we use the load_dataset function to load this dataset.

2. Load with load_dataset

Load the data directly with load_dataset, using SQuAD_it-train.json as an example:

>>> squad_it_dataset = load_dataset("json", data_files="SQuAD_it-train.json", field="data")
Downloading and preparing dataset json/default to /Users/harry/.cache/huggingface/datasets/json/default-e0b956320ae13300/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...
Downloading data files: 100%|███████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8774.69it/s]
Extracting data files: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 463.15it/s]
Dataset json downloaded and prepared to /Users/harry/.cache/huggingface/datasets/json/default-e0b956320ae13300/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 355.36it/s]
>>> print(squad_it_dataset)
DatasetDict({
    train: Dataset({
        features: ['title', 'paragraphs'],
        num_rows: 442
    })
})

In the load_dataset call, "json" means the json loading script is used, data_files is the path to the data file, and field specifies which top-level key contains the records; here that key is data, so we set field="data".

After loading, the returned dataset is of type DatasetDict. If we don't configure anything else, the loaded data is assigned to a train split by default, and we can index into it directly to access individual examples.

>>> squad_it_dataset["train"][0]
{
    "title": "Terremoto del Sichuan del 2008",
    "paragraphs": [
        {
            "context": "Il terremoto del Sichuan del 2008 o il terremoto...",
            "qas": [
                {
                    "answers": [{"answer_start": 29, "text": "2008"}],
                    "id": "56cdca7862d2951400fa6826",
                    "question": "In quale anno si è verificato il terremoto nel Sichuan?",
                },
                ...

As shown above, it prints the first title together with the contents of paragraphs, because this is the first entry in the data list.

3. Loading train and test splits more flexibly

By default, a loaded file is assigned to train. If we want to set train and test separately, we can pass a dictionary to the data_files argument of load_dataset.

>>> data_files = {"train": "SQuAD_it-train.json", "test": "SQuAD_it-test.json"}
>>> squad_it_dataset = load_dataset("json", data_files=data_files, field="data")
Downloading and preparing dataset json/default to /Users/harry/.cache/huggingface/datasets/json/default-deaf5fe77027f091/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...
Downloading data files: 100%|████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 12846.26it/s]
Extracting data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 643.50it/s]
Dataset json downloaded and prepared to /Users/harry/.cache/huggingface/datasets/json/default-deaf5fe77027f091/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 462.97it/s]
>>> squad_it_dataset
DatasetDict({
    train: Dataset({
        features: ['paragraphs', 'title'],
        num_rows: 442
    })
    test: Dataset({
        features: ['paragraphs', 'title'],
        num_rows: 48
    })
})

In the example above we downloaded the dataset beforehand and then assigned train and test through data_files. Moreover, if the files are compressed or stored remotely, load_dataset can still decompress or download them automatically, as shown below.

# load_dataset can decompress automatically
data_files = {"train": "SQuAD_it-train.json.gz", "test": "SQuAD_it-test.json.gz"}
squad_it_dataset = load_dataset("json", data_files=data_files, field="data")

# load_dataset can download automatically
url = "https://github.com/crux82/squad-it/raw/master/"
data_files = {
    "train": url + "SQuAD_it-train.json.gz",
    "test": url + "SQuAD_it-test.json.gz",
}
squad_it_dataset = load_dataset("json", data_files=data_files, field="data")

2. Operating on datasets

All of the dataset operations below can be found in the official documentation.

2.1 The filter method

The filter method removes examples from a dataset according to a condition. For example, in the glue/mrpc dataset, suppose we want to keep only the examples whose sentence1 starts with the character " and drop everything else; this can be done as follows.

>>> raw_datasets = load_dataset("glue", "mrpc")
Found cached dataset glue (/Users/harry/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 375.89it/s]
>>> sentence_sample = raw_datasets['train'].shuffle(seed=42).select(range(100))
>>> sentence_sample[:3]
{'sentence1': ['" The public is understandably losing patience with these unwanted phone calls , unwanted intrusions , " he said at a White House ceremony .', 'Federal agent Bill Polychronopoulos said it was not known if the man , 30 , would be charged .', 'The companies uniformly declined to give specific numbers on customer turnover , saying they will release those figures only when they report overall company performance at year-end .'], 'sentence2': ['" While many good people work in the telemarketing industry , the public is understandably losing patience with these unwanted phone calls , unwanted intrusions , " Mr. Bush said .', 'Federal Agent Bill Polychronopoulos said last night the man involved in the Melbourne incident had been unarmed .', 'The companies , however , declined to give specifics on customer turnover , saying they would release figures only when they report their overall company performance .'], 'label': [0, 0, 1], 'idx': [3946, 3683, 3919]}
>>> sentence_sample = raw_datasets['train']
>>> print(sentence_sample)
Dataset({
    features: ['sentence1', 'sentence2', 'label', 'idx'],
    num_rows: 3668
})
>>> sentence_sample = sentence_sample.filter(lambda x: x["sentence1"][0]=="\"")
>>> print(sentence_sample)
Dataset({
    features: ['sentence1', 'sentence2', 'label', 'idx'],
    num_rows: 343
})
>>> sentence_sample[:3]
{'sentence1': ['" I think you \'ll see a lot of job growth in the next two years , " he said , adding the growth could replace jobs lost .', '" The result is an overall package that will provide significant economic growth for our employees over the next four years . "', '" We are declaring war on sexual harassment and sexual assault .'], 'sentence2': ['" I think you \'ll see a lot of job growth in the next two years , " said Mankiw .', '" The result is an overall package that will provide a significant economic growth for our employees over the next few years , " he said .', '" We have declared war on sexual assault and sexual harassment , " Rosa said .'], 'label': [0, 1, 1], 'idx': [20, 49, 89]}
>>> 

The anonymous lambda in the filter call above can also be replaced with a normal function, for example:

def filter_quote(x):
    return x["sentence1"][0] == "\""

sentence_sample = sentence_sample.filter(filter_quote)

2.2 The map method

The map method lets us apply a function to every example in the dataset; its most common arguments are sketched below.
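
A rough sketch of the arguments that come up most often (names follow Dataset.map in the datasets library; add_prefix is just a placeholder function):

# Illustrative use of the common Dataset.map arguments
def add_prefix(example):
    example["sentence1"] = "MRPC: " + example["sentence1"]
    return example

mapped = raw_datasets["train"].map(
    add_prefix,
    batched=False,             # set True to receive a batch (dict of lists) per call
    # batch_size=1000,         # batch size when batched=True
    # num_proc=4,              # number of processes for parallel mapping
    # remove_columns=["idx"],  # drop columns after mapping
)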

1. Uppercase every sentence1 in the glue/mrpc dataset

>>> sentence_sample = raw_datasets['train']
>>> def upper_sentence(example):
...     example["sentence1"] = example["sentence1"].upper()
...     return example
... 
>>> sentence_sample = sentence_sample.map(upper_sentence)
>>> sentence_sample[0]
{'sentence1': 'AMROZI ACCUSED HIS BROTHER , WHOM HE CALLED " THE WITNESS " , OF DELIBERATELY DISTORTING HIS EVIDENCE .', 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .', 'label': 1, 'idx': 0}
>>> raw_datasets['train']['sentence1'][0]
'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .'
>>> sentence_sample['sentence1'][0]
'AMROZI ACCUSED HIS BROTHER , WHOM HE CALLED " THE WITNESS " , OF DELIBERATELY DISTORTING HIS EVIDENCE .'

2. Add a column to the glue/mrpc dataset with the length of each sentence1 (the same works for sentence2)

Note that the function's return value must be a dict; the returned dict automatically adds a new column to the dataset. Likewise, if you want to modify an existing column, refer to example 1.

>>> sentence_sample = raw_datasets['train']
>>> def compute_length(example):
...     sentence_length = {}
...     sentence_length["sentence1_len"] = len(example["sentence1"])
...     return sentence_length
... 
>>> sentence_sample = sentence_sample.map(compute_length)
>>> sentence_sample
Dataset({
    features: ['sentence1', 'sentence2', 'label', 'idx', 'sentence1_len'],
    num_rows: 3668
})
>>> sentence_sample[0]
{'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .', 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .', 'label': 1, 'idx': 0, 'sentence1_len': 103}
>>> 

2.3 The rename_column method

The rename_column method lets us rename a column, i.e. relabel it. For example, to rename sentence1 and sentence2 in the glue/mrpc dataset to sen1 and sen2:

>>> sentence_sample = raw_datasets['train']
>>> sentence_sample = sentence_sample.rename_column(original_column_name="sentence1", new_column_name="sen1")
>>> sentence_sample = sentence_sample.rename_column(original_column_name="sentence2", new_column_name="sen2")
>>> print(sentence_sample)
Dataset({
    features: ['sen1', 'sen2', 'label', 'idx'],
    num_rows: 3668
})

2.4 Other methods

The following methods are summarized from the reference documentation; a combined usage sketch is given after this list.

1. sort

Sort the dataset.

sortData = dataset.sort('label')

2. shuffle

Shuffle the dataset.

shuffleData = sortData.shuffle(seed=20)

3. select

Select the examples at the given indices.

dataset.select([0,1,2,3])

4. filter

Filter the dataset.

def starts_with_1(data):
    return data['text'].startswith('1')
b = dataset.filter(starts_with_1)

5. train_test_split

Split the dataset into train and test sets, e.g. hold out 10% of the data as the test set.

dataset.train_test_split(test_size=0.1)

6. shard

Split the dataset into several shards, e.g. split it into 5 shards and take the first one.

dataset.shard(num_shards=5, index=0)

7. rename_column

Rename a column. This is faster than renaming through map because no new data needs to be copied.

c = a.rename_column('text', 'newColumn')

8. remove_columns

Remove a column.

d = c.remove_columns(['newColumn'])

9. map

Apply a function to every example in the dataset.

def handler(data):
  data['text'] = 'Prefix' + data['text']
  return data

datasetMap = dataset.map(handler)

10. save_to_disk/load_from_disk

Save a dataset to disk and load it back.

dataset.save_to_disk('./')

from datasets import load_from_disk
dataset = load_from_disk('./')
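
As a rough end-to-end sketch that chains several of the methods above (using the glue/mrpc train split from earlier; everything here is illustrative):

from datasets import load_dataset

dataset = load_dataset("glue", "mrpc", split="train")

small = (
    dataset
    .shuffle(seed=42)                    # shuffle
    .select(range(1000))                 # keep the first 1000 shuffled rows
    .filter(lambda x: x["label"] == 1)   # keep only paraphrase pairs
    .rename_column("sentence1", "sen1")  # rename a column
)

splits = small.train_test_split(test_size=0.1)  # DatasetDict with "train"/"test"
splits.save_to_disk("./mrpc_small")             # reload later with load_from_disk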

3. Custom data loading scripts

Above we simply called the load_dataset method from datasets to load data from the Huggingface Hub or from local files. But what if the dataset is complex and we want to customize the loading script? Huggingface provides a complete and flexible framework for data loading scripts; we can follow the official guide to write our own, and the docs also provide an example loading script.

3.1 The two core classes to implement

A custom loading script implements two core classes, datasets.BuilderConfig and datasets.GeneratorBasedBuilder. The first maintains the dataset's metadata, and the second implements the loading and processing. The overall skeleton looks like this:

from datasets import BuilderConfig, GeneratorBasedBuilder

class MyBuilderConfig(BuilderConfig):
    def __init__(self, **kwargs):
        super(MyBuilderConfig, self).__init__(**kwargs)

class MyDatasetBuilder(GeneratorBasedBuilder):
    BUILDER_CONFIGS = [MyBuilderConfig(name="default")]  # list of dataset configurations

    def _info(self):
        pass

    def _split_generators(self, dl_manager):
        pass

    def _generate_examples(self, filepath):
        pass

3.2 Building the BuilderConfig

A BuilderConfig can be instantiated directly with attributes such as name, version, and description:

BuilderConfig(name="first_domain", version=VERSION, description="This part of my dataset covers a first domain")
BuilderConfig(name="second_domain", version=VERSION, description="This part of my dataset covers a second domain"),

We can also subclass it to add custom attributes. In the official example, SuperGlueConfig inherits from BuilderConfig and adds attributes such as label_classes:

class SuperGlueConfig(datasets.BuilderConfig):
    """BuilderConfig for SuperGLUE."""

    def __init__(self, features, data_url, citation, url, label_classes=("False", "True"), **kwargs):
        """BuilderConfig for SuperGLUE.

        Args:
        features: *list[string]*, list of the features that will appear in the
            feature dict. Should not include "label".
        data_url: *string*, url to download the zip file from.
        citation: *string*, citation for the data set.
        url: *string*, url for information about the data set.
        label_classes: *list[string]*, the list of classes for the label if the
            label is present as a string. Non-string labels will be cast to either
            'False' or 'True'.
        **kwargs: keyword arguments forwarded to super.
        """
        # Version history:
        # 1.0.2: Fixed non-nondeterminism in ReCoRD.
        # 1.0.1: Change from the pre-release trial version of SuperGLUE (v1.9) to
        #        the full release (v2.0).
        # 1.0.0: S3 (new shuffling, sharding and slicing mechanism).
        # 0.0.2: Initial version.
        super().__init__(version=datasets.Version("1.0.2"), **kwargs)
        self.features = features
        self.label_classes = label_classes
        self.data_url = data_url
        self.citation = citation
        self.url = url

3.3 Building the GeneratorBasedBuilder

GeneratorBasedBuilder is the class that downloads and processes the dataset. We need to subclass it and implement three key methods.

1. The _info() method

This method defines the dataset's metadata, including its features, homepage, citation, and so on. features is a datasets.Features object, and each feature can be defined with datasets.Value. Below is an example; the method only needs to return an instantiated datasets.DatasetInfo object.

def _info(self):
    return datasets.DatasetInfo(
        description=_DESCRIPTION,
        features=datasets.Features(
            {
                "id": datasets.Value("string"),
                # others
            }
        ),
        supervised_keys=None,
        homepage="https://<url>/",
        citation=_CITATION,
    )

2. _split_generators

This method downloads the data and defines how it is split, e.g. into separate train, validation, and test sets. It must return a list of datasets.SplitGenerator objects, where name is the split name and gen_kwargs is a dict containing things like filepath and split, which are passed on to _generate_examples.

def _split_generators(self, dl_manager):
    """Returns SplitGenerators."""
    downloaded_file = dl_manager.download_and_extract("https://url/dataset.zip")
    return [
        datasets.SplitGenerator(
            name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
        ),
        datasets.SplitGenerator(
            name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
        ),
    ]

3. _generate_examples

This method does the actual processing: it receives the arguments defined in gen_kwargs of each SplitGenerator and yields the examples in the desired form.

def _generate_examples(self, filepath):
    logger.info("⏳ Generating examples from = %s", filepath)
    ann_dir = os.path.join(filepath, "annotations")
    img_dir = os.path.join(filepath, "images")
    for guid, file in enumerate(sorted(os.listdir(ann_dir))):
        # process data code
        yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes}

3.4 Loading the dataset

The complete FUNSD loading script, combining the three methods above, looks like this:

class Funsd(datasets.GeneratorBasedBuilder):
    """Conll2003 dataset."""

    BUILDER_CONFIGS = [
        FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
    ]

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "tokens": datasets.Sequence(datasets.Value("string")),
                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
                    "ner_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
                        )
                    ),
                    "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
                }
            ),
            supervised_keys=None,
            homepage="https://guillaumejaume.github.io/FUNSD/",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
            ),
        ]

    def _generate_examples(self, filepath):
        logger.info("⏳ Generating examples from = %s", filepath)
        ann_dir = os.path.join(filepath, "annotations")
        img_dir = os.path.join(filepath, "images")
        for guid, file in enumerate(sorted(os.listdir(ann_dir))):
            tokens = []
            bboxes = []
            ner_tags = []

            file_path = os.path.join(ann_dir, file)
            with open(file_path, "r", encoding="utf8") as f:
                data = json.load(f)
            image_path = os.path.join(img_dir, file)
            image_path = image_path.replace("json", "png")
            image, size = load_image(image_path)
            for item in data["form"]:
                words, label = item["words"], item["label"]
                words = [w for w in words if w["text"].strip() != ""]
                if len(words) == 0:
                    continue
                if label == "other":
                    for w in words:
                        tokens.append(w["text"])
                        ner_tags.append("O")
                        bboxes.append(normalize_bbox(w["box"], size))
                else:
                    tokens.append(words[0]["text"])
                    ner_tags.append("B-" + label.upper())
                    bboxes.append(normalize_bbox(words[0]["box"], size))
                    for w in words[1:]:
                        tokens.append(w["text"])
                        ner_tags.append("I-" + label.upper())
                        bboxes.append(normalize_bbox(w["box"], size))

            yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
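
Once this script is saved locally (say as funsd.py; the path below is hypothetical), it can be loaded with load_dataset just like a Hub dataset:

from datasets import load_dataset

# Point load_dataset at the local loading script
dataset = load_dataset("path/to/funsd.py", name="funsd")
print(dataset)   # DatasetDict with train and test splits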

Upload a dataset with huggingface-cli:

huggingface-cli upload harrytea/LLaVA-small-data /apdcephfs/llm-cfs-nj/person/harryyhwang/dataset_test . --repo-type=dataset

Upload large files with upload_large_folder:

from huggingface_hub import HfApi

api = HfApi()
api.upload_large_folder(
    repo_id="harrytea/LLaVA-small-data",
    repo_type="dataset",
    folder_path="/llm-cfs-nj/person/harryyhwang/compressed/dataset/small",
)

The files under small will be uploaded to the harrytea/LLaVA-small-data repository.

(internvl) root@ts-8b1d822e9463bdfd0194817d7a552ad7-launcher:/llm-cfs-nj/person/harryyhwang# python upload_hf.py
Recovering from metadata files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1176/1176 [00:25<00:00, 45.46it/s]



---------- 2025-01-22 16:06:17 (0:00:00) ----------
---------- 2025-01-22 16:08:17 (0:02:00) ----------
Files:   hashed 170/1176 (9.7G/1.5T) | pre-uploaded: 0/1 (0.0/1.5T) (+1175 unsure) | committed: 0/1176 (0.0/1.5T) | ignored: 0
Workers: hashing: 183 | get upload mode: 7 | pre-uploading: 0 | committing: 0 | waiting: 0
---------------------------------------------------