kiln_ai.datamodel.dataset_split

Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
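
A minimal end-to-end sketch, assuming a Task already saved on disk with runs (the file path and the loading call follow Kiln's base-model API; treat both as illustrative):

from kiln_ai.datamodel import Task
from kiln_ai.datamodel.dataset_split import (
    DatasetSplit,
    Train80Test20SplitDefinition,
)

# Load an existing task (path is hypothetical)
task = Task.load_from_file("path/to/task.kiln")

# Freeze its runs into a reproducible 80/20 train/test split
split = DatasetSplit.from_task(
    name="tune_v1",
    task=task,
    splits=Train80Test20SplitDefinition,
)
train_ids = split.split_contents["train"]  # task run IDs in the train split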

  1"""
  2Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
  3"""
  4
  5import math
  6import random
  7from typing import TYPE_CHECKING
  8
  9from pydantic import BaseModel, Field, model_validator
 10
 11from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
 12from kiln_ai.datamodel.dataset_filters import (
 13    DatasetFilter,
 14    DatasetFilterId,
 15    dataset_filter_from_id,
 16)
 17
 18if TYPE_CHECKING:
 19    from kiln_ai.datamodel.task import Task
 20
 21
 22class DatasetSplitDefinition(BaseModel):
 23    """
 24    A definition of a split in a dataset.
 25
 26    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
 27    """
 28
 29    name: FilenameString = Field(
 30        description="The name of the dataset split definition."
 31    )
 32    description: str | None = Field(
 33        default=None,
 34        description="A description of the dataset for you and your team. Not used in training.",
 35    )
 36    percentage: float = Field(
 37        ge=0.0,
 38        le=1.0,
 39        description="The percentage of the dataset that this split represents (between 0 and 1).",
 40    )
 41
 42
 43AllSplitDefinition: list[DatasetSplitDefinition] = [
 44    DatasetSplitDefinition(name="all", percentage=1.0)
 45]
 46Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
 47    DatasetSplitDefinition(name="train", percentage=0.8),
 48    DatasetSplitDefinition(name="test", percentage=0.2),
 49]
 50Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
 51    DatasetSplitDefinition(name="train", percentage=0.8),
 52    DatasetSplitDefinition(name="val", percentage=0.2),
 53]
 54Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
 55    DatasetSplitDefinition(name="train", percentage=0.6),
 56    DatasetSplitDefinition(name="test", percentage=0.2),
 57    DatasetSplitDefinition(name="val", percentage=0.2),
 58]
 59Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
 60    DatasetSplitDefinition(name="train", percentage=0.8),
 61    DatasetSplitDefinition(name="test", percentage=0.1),
 62    DatasetSplitDefinition(name="val", percentage=0.1),
 63]
 64
 65
 66class DatasetSplit(KilnParentedModel):
 67    """
 68    A collection of task runs, with optional splits (train, test, validation).
 69
 70    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
 71
 72    Maintains a list of IDs for each split, to avoid data duplication.
 73    """
 74
 75    name: FilenameString = Field(description="The name of the dataset split.")
 76    description: str | None = Field(
 77        default=None,
 78        description="A description of the dataset for you and your team. Not used in training.",
 79    )
 80    splits: list[DatasetSplitDefinition] = Field(
 81        default_factory=list,
 82        description="The splits in the dataset.",
 83    )
 84    split_contents: dict[str, list[str]] = Field(
 85        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
 86    )
 87    filter: DatasetFilterId | None = Field(
 88        default=None,
 89        description="The filter used to build the dataset.",
 90    )
 91
 92    @model_validator(mode="after")
 93    def validate_split_percentages(self) -> "DatasetSplit":
 94        total = sum(split.percentage for split in self.splits)
 95        if not math.isclose(total, 1.0, rel_tol=1e-9):
 96            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
 97        return self
 98
 99    @classmethod
100    def from_task(
101        cls,
102        name: str,
103        task: "Task",
104        splits: list[DatasetSplitDefinition],
105        filter_id: DatasetFilterId = "all",
106        description: str | None = None,
107    ):
108        """
109        Build a dataset split from a task.
110        """
111        filter = dataset_filter_from_id(filter_id)
112        split_contents = cls.build_split_contents(task, splits, filter)
113        return cls(
114            parent=task,
115            name=name,
116            description=description,
117            splits=splits,
118            split_contents=split_contents,
119            filter=filter_id,
120        )
121
122    @classmethod
123    def build_split_contents(
124        cls,
125        task: "Task",
126        splits: list[DatasetSplitDefinition],
127        filter: DatasetFilter,
128    ) -> dict[str, list[str]]:
129        valid_ids = []
130        for task_run in task.runs():
131            if filter(task_run):
132                valid_ids.append(task_run.id)
133
134        # Shuffle and split by split percentage
135        random.shuffle(valid_ids)
136        split_contents = {}
137        start_idx = 0
138        remaining_items = len(valid_ids)
139
140        # Handle all splits except the last one
141        for split in splits[:-1]:
142            split_size = round(len(valid_ids) * split.percentage)
143            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
144            start_idx += split_size
145            remaining_items -= split_size
146
147        # Last split gets all remaining items (for rounding)
148        if splits:
149            split_contents[splits[-1].name] = valid_ids[start_idx:]
150
151        return split_contents
152
153    def parent_task(self) -> "Task | None":
154        # inline import to avoid circular import
155        from kiln_ai.datamodel import Task
156
157        if not isinstance(self.parent, Task):
158            return None
159        return self.parent
160
161    def missing_count(self) -> int:
162        """
163        Returns:
164            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
165        """
166        parent = self.parent_task()
167        if parent is None:
168            raise ValueError("DatasetSplit has no parent task")
169
170        runs = parent.runs(readonly=True)
171        all_ids = set(run.id for run in runs)
172        all_ids_in_splits = set()
173        for ids in self.split_contents.values():
174            all_ids_in_splits.update(ids)
175        missing = all_ids_in_splits - all_ids
176        return len(missing)
class DatasetSplitDefinition(pydantic.main.BaseModel):
23class DatasetSplitDefinition(BaseModel):
24    """
25    A definition of a split in a dataset.
26
27    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
28    """
29
30    name: FilenameString = Field(
31        description="The name of the dataset split definition."
32    )
33    description: str | None = Field(
34        default=None,
35        description="A description of the dataset for you and your team. Not used in training.",
36    )
37    percentage: float = Field(
38        ge=0.0,
39        le=1.0,
40        description="The percentage of the dataset that this split represents (between 0 and 1).",
41    )

A definition of a split in a dataset.

Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)

name: FilenameString
description: str | None
percentage: float
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

AllSplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='all', description=None, percentage=1.0)]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.2)]
Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='val', description=None, percentage=0.2)]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.6), DatasetSplitDefinition(name='test', description=None, percentage=0.2), DatasetSplitDefinition(name='val', description=None, percentage=0.2)]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.1), DatasetSplitDefinition(name='val', description=None, percentage=0.1)]
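
The predefined lists above cover common ratios. A custom list is built the same way, provided the percentages sum to 1.0 (enforced by DatasetSplit's validator). A sketch:

custom_splits = [
    DatasetSplitDefinition(name="train", percentage=0.7),
    DatasetSplitDefinition(name="test", percentage=0.15),
    DatasetSplitDefinition(name="val", percentage=0.15),
]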
class DatasetSplit(kiln_ai.datamodel.basemodel.KilnParentedModel):
 67class DatasetSplit(KilnParentedModel):
 68    """
 69    A collection of task runs, with optional splits (train, test, validation).
 70
 71    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
 72
 73    Maintains a list of IDs for each split, to avoid data duplication.
 74    """
 75
 76    name: FilenameString = Field(description="The name of the dataset split.")
 77    description: str | None = Field(
 78        default=None,
 79        description="A description of the dataset for you and your team. Not used in training.",
 80    )
 81    splits: list[DatasetSplitDefinition] = Field(
 82        default_factory=list,
 83        description="The splits in the dataset.",
 84    )
 85    split_contents: dict[str, list[str]] = Field(
 86        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
 87    )
 88    filter: DatasetFilterId | None = Field(
 89        default=None,
 90        description="The filter used to build the dataset.",
 91    )
 92
 93    @model_validator(mode="after")
 94    def validate_split_percentages(self) -> "DatasetSplit":
 95        total = sum(split.percentage for split in self.splits)
 96        if not math.isclose(total, 1.0, rel_tol=1e-9):
 97            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
 98        return self
 99
100    @classmethod
101    def from_task(
102        cls,
103        name: str,
104        task: "Task",
105        splits: list[DatasetSplitDefinition],
106        filter_id: DatasetFilterId = "all",
107        description: str | None = None,
108    ):
109        """
110        Build a dataset split from a task.
111        """
112        filter = dataset_filter_from_id(filter_id)
113        split_contents = cls.build_split_contents(task, splits, filter)
114        return cls(
115            parent=task,
116            name=name,
117            description=description,
118            splits=splits,
119            split_contents=split_contents,
120            filter=filter_id,
121        )
122
123    @classmethod
124    def build_split_contents(
125        cls,
126        task: "Task",
127        splits: list[DatasetSplitDefinition],
128        filter: DatasetFilter,
129    ) -> dict[str, list[str]]:
130        valid_ids = []
131        for task_run in task.runs():
132            if filter(task_run):
133                valid_ids.append(task_run.id)
134
135        # Shuffle and split by split percentage
136        random.shuffle(valid_ids)
137        split_contents = {}
138        start_idx = 0
139        remaining_items = len(valid_ids)
140
141        # Handle all splits except the last one
142        for split in splits[:-1]:
143            split_size = round(len(valid_ids) * split.percentage)
144            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
145            start_idx += split_size
146            remaining_items -= split_size
147
148        # Last split gets all remaining items (for rounding)
149        if splits:
150            split_contents[splits[-1].name] = valid_ids[start_idx:]
151
152        return split_contents
153
154    def parent_task(self) -> "Task | None":
155        # inline import to avoid circular import
156        from kiln_ai.datamodel import Task
157
158        if not isinstance(self.parent, Task):
159            return None
160        return self.parent
161
162    def missing_count(self) -> int:
163        """
164        Returns:
165            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
166        """
167        parent = self.parent_task()
168        if parent is None:
169            raise ValueError("DatasetSplit has no parent task")
170
171        runs = parent.runs(readonly=True)
172        all_ids = set(run.id for run in runs)
173        all_ids_in_splits = set()
174        for ids in self.split_contents.values():
175            all_ids_in_splits.update(ids)
176        missing = all_ids_in_splits - all_ids
177        return len(missing)

A collection of task runs, with optional splits (train, test, validation).

Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

Maintains a list of IDs for each split, to avoid data duplication.

name: FilenameString
description: str | None
splits: list[DatasetSplitDefinition]
split_contents: dict[str, list[str]]
filter: DatasetFilterId | None
@model_validator(mode='after')
def validate_split_percentages(self) -> DatasetSplit:
93    @model_validator(mode="after")
94    def validate_split_percentages(self) -> "DatasetSplit":
95        total = sum(split.percentage for split in self.splits)
96        if not math.isclose(total, 1.0, rel_tol=1e-9):
97            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
98        return self
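
For example, constructing a DatasetSplit whose percentages do not sum to 1.0 fails at validation time (a minimal sketch; real instances are normally built via from_task):

DatasetSplit(
    name="bad_split",
    splits=[DatasetSplitDefinition(name="train", percentage=0.5)],
    split_contents={},
)
# ValueError: The sum of split percentages must be 1.0 (got 0.5)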
@classmethod
def from_task( cls, name: str, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter_id: DatasetFilterId = 'all', description: str | None = None):
100    @classmethod
101    def from_task(
102        cls,
103        name: str,
104        task: "Task",
105        splits: list[DatasetSplitDefinition],
106        filter_id: DatasetFilterId = "all",
107        description: str | None = None,
108    ):
109        """
110        Build a dataset split from a task.
111        """
112        filter = dataset_filter_from_id(filter_id)
113        split_contents = cls.build_split_contents(task, splits, filter)
114        return cls(
115            parent=task,
116            name=name,
117            description=description,
118            splits=splits,
119            split_contents=split_contents,
120            filter=filter_id,
121        )

Build a dataset split from a task.

@classmethod
def build_split_contents( cls, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter: kiln_ai.datamodel.dataset_filters.DatasetFilter) -> dict[str, list[str]]:
123    @classmethod
124    def build_split_contents(
125        cls,
126        task: "Task",
127        splits: list[DatasetSplitDefinition],
128        filter: DatasetFilter,
129    ) -> dict[str, list[str]]:
130        valid_ids = []
131        for task_run in task.runs():
132            if filter(task_run):
133                valid_ids.append(task_run.id)
134
135        # Shuffle and split by split percentage
136        random.shuffle(valid_ids)
137        split_contents = {}
138        start_idx = 0
139        remaining_items = len(valid_ids)
140
141        # Handle all splits except the last one
142        for split in splits[:-1]:
143            split_size = round(len(valid_ids) * split.percentage)
144            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
145            start_idx += split_size
146            remaining_items -= split_size
147
148        # Last split gets all remaining items (for rounding)
149        if splits:
150            split_contents[splits[-1].name] = valid_ids[start_idx:]
151
152        return split_contents
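
Note the rounding behavior: every split except the last receives round(n * percentage) IDs, and the last split absorbs whatever remains. With 5 valid runs and the 60/20/20 definition, for example, train gets round(5 * 0.6) = 3 IDs, test gets round(5 * 0.2) = 1, and val receives the remaining 1.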
def parent_task(self) -> kiln_ai.datamodel.Task | None:
154    def parent_task(self) -> "Task | None":
155        # inline import to avoid circular import
156        from kiln_ai.datamodel import Task
157
158        if not isinstance(self.parent, Task):
159            return None
160        return self.parent
def missing_count(self) -> int:
162    def missing_count(self) -> int:
163        """
164        Returns:
165            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
166        """
167        parent = self.parent_task()
168        if parent is None:
169            raise ValueError("DatasetSplit has no parent task")
170
171        runs = parent.runs(readonly=True)
172        all_ids = set(run.id for run in runs)
173        all_ids_in_splits = set()
174        for ids in self.split_contents.values():
175            all_ids_in_splits.update(ids)
176        missing = all_ids_in_splits - all_ids
177        return len(missing)

Returns: int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
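
For example, if split_contents references run IDs "1", "2", and "3", but run "3" has since been deleted from the parent task, missing_count() returns 1.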

def relationship_name() -> str:
464        def relationship_name_method() -> str:
465            return relationship_name

Returns the name of this model's relationship to its parent model.

def parent_type() -> Type[kiln_ai.datamodel.basemodel.KilnParentModel]:
457        def parent_class_method() -> Type[KilnParentModel]:
458            return cls

Returns the expected parent model type (Task, for DatasetSplit).

model_config = {'validate_assignment': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
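
Because validate_assignment is enabled, reassigning a field on an existing instance re-runs validation; for example, setting split.splits to a list whose percentages do not sum to 1.0 raises the same ValueError as at construction time.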
