kiln_ai.datamodel.dataset_split

Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.

  1"""
  2Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
  3"""
  4
  5import math
  6import random
  7from typing import TYPE_CHECKING
  8
  9from pydantic import BaseModel, Field, model_validator
 10
 11from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
 12from kiln_ai.datamodel.dataset_filters import (
 13    DatasetFilter,
 14    DatasetFilterId,
 15    dataset_filter_from_id,
 16)
 17
 18if TYPE_CHECKING:
 19    from kiln_ai.datamodel.task import Task
 20
 21
 22class DatasetSplitDefinition(BaseModel):
 23    """
 24    A definition of a split in a dataset.
 25
 26    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
 27    """
 28
 29    name: str = NAME_FIELD
 30    description: str | None = Field(
 31        default=None,
 32        description="A description of the dataset for you and your team. Not used in training.",
 33    )
 34    percentage: float = Field(
 35        ge=0.0,
 36        le=1.0,
 37        description="The percentage of the dataset that this split represents (between 0 and 1).",
 38    )
 39
 40
 41AllSplitDefinition: list[DatasetSplitDefinition] = [
 42    DatasetSplitDefinition(name="all", percentage=1.0)
 43]
 44Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
 45    DatasetSplitDefinition(name="train", percentage=0.8),
 46    DatasetSplitDefinition(name="test", percentage=0.2),
 47]
 48Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
 49    DatasetSplitDefinition(name="train", percentage=0.8),
 50    DatasetSplitDefinition(name="val", percentage=0.2),
 51]
 52Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
 53    DatasetSplitDefinition(name="train", percentage=0.6),
 54    DatasetSplitDefinition(name="test", percentage=0.2),
 55    DatasetSplitDefinition(name="val", percentage=0.2),
 56]
 57Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
 58    DatasetSplitDefinition(name="train", percentage=0.8),
 59    DatasetSplitDefinition(name="test", percentage=0.1),
 60    DatasetSplitDefinition(name="val", percentage=0.1),
 61]
 62
 63
 64class DatasetSplit(KilnParentedModel):
 65    """
 66    A collection of task runs, with optional splits (train, test, validation).
 67
 68    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
 69
 70    Maintains a list of IDs for each split, to avoid data duplication.
 71    """
 72
 73    name: str = NAME_FIELD
 74    description: str | None = Field(
 75        default=None,
 76        description="A description of the dataset for you and your team. Not used in training.",
 77    )
 78    splits: list[DatasetSplitDefinition] = Field(
 79        default_factory=list,
 80        description="The splits in the dataset.",
 81    )
 82    split_contents: dict[str, list[str]] = Field(
 83        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
 84    )
 85    filter: DatasetFilterId | None = Field(
 86        default=None,
 87        description="The filter used to build the dataset.",
 88    )
 89
 90    @model_validator(mode="after")
 91    def validate_split_percentages(self) -> "DatasetSplit":
 92        total = sum(split.percentage for split in self.splits)
 93        if not math.isclose(total, 1.0, rel_tol=1e-9):
 94            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
 95        return self
 96
 97    @classmethod
 98    def from_task(
 99        cls,
100        name: str,
101        task: "Task",
102        splits: list[DatasetSplitDefinition],
103        filter_id: DatasetFilterId = "all",
104        description: str | None = None,
105    ):
106        """
107        Build a dataset split from a task.
108        """
109        filter = dataset_filter_from_id(filter_id)
110        split_contents = cls.build_split_contents(task, splits, filter)
111        return cls(
112            parent=task,
113            name=name,
114            description=description,
115            splits=splits,
116            split_contents=split_contents,
117            filter=filter_id,
118        )
119
120    @classmethod
121    def build_split_contents(
122        cls,
123        task: "Task",
124        splits: list[DatasetSplitDefinition],
125        filter: DatasetFilter,
126    ) -> dict[str, list[str]]:
127        valid_ids = []
128        for task_run in task.runs():
129            if filter(task_run):
130                valid_ids.append(task_run.id)
131
132        # Shuffle and split by split percentage
133        random.shuffle(valid_ids)
134        split_contents = {}
135        start_idx = 0
136        remaining_items = len(valid_ids)
137
138        # Handle all splits except the last one
139        for split in splits[:-1]:
140            split_size = round(len(valid_ids) * split.percentage)
141            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
142            start_idx += split_size
143            remaining_items -= split_size
144
145        # Last split gets all remaining items (for rounding)
146        if splits:
147            split_contents[splits[-1].name] = valid_ids[start_idx:]
148
149        return split_contents
150
151    def parent_task(self) -> "Task | None":
152        # inline import to avoid circular import
153        from kiln_ai.datamodel import Task
154
155        if not isinstance(self.parent, Task):
156            return None
157        return self.parent
158
159    def missing_count(self) -> int:
160        """
161        Returns:
162            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
163        """
164        parent = self.parent_task()
165        if parent is None:
166            raise ValueError("DatasetSplit has no parent task")
167
168        runs = parent.runs(readonly=True)
169        all_ids = set(run.id for run in runs)
170        all_ids_in_splits = set()
171        for ids in self.split_contents.values():
172            all_ids_in_splits.update(ids)
173        missing = all_ids_in_splits - all_ids
174        return len(missing)
class DatasetSplitDefinition(pydantic.main.BaseModel):
23class DatasetSplitDefinition(BaseModel):
24    """
25    A definition of a split in a dataset.
26
27    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
28    """
29
30    name: str = NAME_FIELD
31    description: str | None = Field(
32        default=None,
33        description="A description of the dataset for you and your team. Not used in training.",
34    )
35    percentage: float = Field(
36        ge=0.0,
37        le=1.0,
38        description="The percentage of the dataset that this split represents (between 0 and 1).",
39    )

A definition of a split in a dataset.

Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)

name: str
description: str | None
percentage: float
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

AllSplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='all', description=None, percentage=1.0)]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.2)]
Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='val', description=None, percentage=0.2)]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.6), DatasetSplitDefinition(name='test', description=None, percentage=0.2), DatasetSplitDefinition(name='val', description=None, percentage=0.2)]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.1), DatasetSplitDefinition(name='val', description=None, percentage=0.1)]
class DatasetSplit(kiln_ai.datamodel.basemodel.KilnParentedModel):
 65class DatasetSplit(KilnParentedModel):
 66    """
 67    A collection of task runs, with optional splits (train, test, validation).
 68
 69    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
 70
 71    Maintains a list of IDs for each split, to avoid data duplication.
 72    """
 73
 74    name: str = NAME_FIELD
 75    description: str | None = Field(
 76        default=None,
 77        description="A description of the dataset for you and your team. Not used in training.",
 78    )
 79    splits: list[DatasetSplitDefinition] = Field(
 80        default_factory=list,
 81        description="The splits in the dataset.",
 82    )
 83    split_contents: dict[str, list[str]] = Field(
 84        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
 85    )
 86    filter: DatasetFilterId | None = Field(
 87        default=None,
 88        description="The filter used to build the dataset.",
 89    )
 90
 91    @model_validator(mode="after")
 92    def validate_split_percentages(self) -> "DatasetSplit":
 93        total = sum(split.percentage for split in self.splits)
 94        if not math.isclose(total, 1.0, rel_tol=1e-9):
 95            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
 96        return self
 97
 98    @classmethod
 99    def from_task(
100        cls,
101        name: str,
102        task: "Task",
103        splits: list[DatasetSplitDefinition],
104        filter_id: DatasetFilterId = "all",
105        description: str | None = None,
106    ):
107        """
108        Build a dataset split from a task.
109        """
110        filter = dataset_filter_from_id(filter_id)
111        split_contents = cls.build_split_contents(task, splits, filter)
112        return cls(
113            parent=task,
114            name=name,
115            description=description,
116            splits=splits,
117            split_contents=split_contents,
118            filter=filter_id,
119        )
120
121    @classmethod
122    def build_split_contents(
123        cls,
124        task: "Task",
125        splits: list[DatasetSplitDefinition],
126        filter: DatasetFilter,
127    ) -> dict[str, list[str]]:
128        valid_ids = []
129        for task_run in task.runs():
130            if filter(task_run):
131                valid_ids.append(task_run.id)
132
133        # Shuffle and split by split percentage
134        random.shuffle(valid_ids)
135        split_contents = {}
136        start_idx = 0
137        remaining_items = len(valid_ids)
138
139        # Handle all splits except the last one
140        for split in splits[:-1]:
141            split_size = round(len(valid_ids) * split.percentage)
142            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
143            start_idx += split_size
144            remaining_items -= split_size
145
146        # Last split gets all remaining items (for rounding)
147        if splits:
148            split_contents[splits[-1].name] = valid_ids[start_idx:]
149
150        return split_contents
151
152    def parent_task(self) -> "Task | None":
153        # inline import to avoid circular import
154        from kiln_ai.datamodel import Task
155
156        if not isinstance(self.parent, Task):
157            return None
158        return self.parent
159
160    def missing_count(self) -> int:
161        """
162        Returns:
163            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
164        """
165        parent = self.parent_task()
166        if parent is None:
167            raise ValueError("DatasetSplit has no parent task")
168
169        runs = parent.runs(readonly=True)
170        all_ids = set(run.id for run in runs)
171        all_ids_in_splits = set()
172        for ids in self.split_contents.values():
173            all_ids_in_splits.update(ids)
174        missing = all_ids_in_splits - all_ids
175        return len(missing)

A collection of task runs, with optional splits (train, test, validation).

Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

Maintains a list of IDs for each split, to avoid data duplication.

name: str
description: str | None
splits: list[DatasetSplitDefinition]
split_contents: dict[str, list[str]]
filter: DatasetFilterId | None
@model_validator(mode='after')
def validate_split_percentages(self) -> DatasetSplit:
91    @model_validator(mode="after")
92    def validate_split_percentages(self) -> "DatasetSplit":
93        total = sum(split.percentage for split in self.splits)
94        if not math.isclose(total, 1.0, rel_tol=1e-9):
95            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
96        return self
@classmethod
def from_task( cls, name: str, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter_id: DatasetFilterId = 'all', description: str | None = None):
 98    @classmethod
 99    def from_task(
100        cls,
101        name: str,
102        task: "Task",
103        splits: list[DatasetSplitDefinition],
104        filter_id: DatasetFilterId = "all",
105        description: str | None = None,
106    ):
107        """
108        Build a dataset split from a task.
109        """
110        filter = dataset_filter_from_id(filter_id)
111        split_contents = cls.build_split_contents(task, splits, filter)
112        return cls(
113            parent=task,
114            name=name,
115            description=description,
116            splits=splits,
117            split_contents=split_contents,
118            filter=filter_id,
119        )

Build a dataset split from a task.

@classmethod
def build_split_contents( cls, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter: kiln_ai.datamodel.dataset_filters.DatasetFilter) -> dict[str, list[str]]:
121    @classmethod
122    def build_split_contents(
123        cls,
124        task: "Task",
125        splits: list[DatasetSplitDefinition],
126        filter: DatasetFilter,
127    ) -> dict[str, list[str]]:
128        valid_ids = []
129        for task_run in task.runs():
130            if filter(task_run):
131                valid_ids.append(task_run.id)
132
133        # Shuffle and split by split percentage
134        random.shuffle(valid_ids)
135        split_contents = {}
136        start_idx = 0
137        remaining_items = len(valid_ids)
138
139        # Handle all splits except the last one
140        for split in splits[:-1]:
141            split_size = round(len(valid_ids) * split.percentage)
142            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
143            start_idx += split_size
144            remaining_items -= split_size
145
146        # Last split gets all remaining items (for rounding)
147        if splits:
148            split_contents[splits[-1].name] = valid_ids[start_idx:]
149
150        return split_contents
def parent_task(self) -> kiln_ai.datamodel.Task | None:
152    def parent_task(self) -> "Task | None":
153        # inline import to avoid circular import
154        from kiln_ai.datamodel import Task
155
156        if not isinstance(self.parent, Task):
157            return None
158        return self.parent
def missing_count(self) -> int:
160    def missing_count(self) -> int:
161        """
162        Returns:
163            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
164        """
165        parent = self.parent_task()
166        if parent is None:
167            raise ValueError("DatasetSplit has no parent task")
168
169        runs = parent.runs(readonly=True)
170        all_ids = set(run.id for run in runs)
171        all_ids_in_splits = set()
172        for ids in self.split_contents.values():
173            all_ids_in_splits.update(ids)
174        missing = all_ids_in_splits - all_ids
175        return len(missing)

Returns: int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset

def relationship_name() -> str:
438        def relationship_name_method() -> str:
439            return relationship_name

Returns the relationship name used to link this model to its parent.

def parent_type() -> Type[kiln_ai.datamodel.basemodel.KilnParentModel]:
431        def parent_class_method() -> Type[KilnParentModel]:
432            return cls

Returns the parent model class (a Type[KilnParentModel]) for this parented model.

model_config = {'validate_assignment': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
122                    def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
123                        """We need to both initialize private attributes and call the user-defined model_post_init
124                        method.
125                        """
126                        init_private_attributes(self, context)
127                        original_model_post_init(self, context)

We need to both initialize private attributes and call the user-defined model_post_init method.