kiln_ai.datamodel.dataset_split

Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
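A minimal usage sketch (the task variable is assumed to be an existing kiln_ai.datamodel.Task with saved runs; it is not defined by this module):

from kiln_ai.datamodel.dataset_split import (
    DatasetSplit,
    Train80Test20SplitDefinition,
)

# Freeze the task's current runs into an 80/20 train/test split.
split = DatasetSplit.from_task(
    name="my split",  # must satisfy NAME_FIELD constraints
    task=task,
    splits=Train80Test20SplitDefinition,
)
print(split.split_contents["train"])  # run IDs assigned to the "train" split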

  1"""
  2Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
  3"""
  4
  5import math
  6import random
  7from typing import TYPE_CHECKING
  8
  9from pydantic import BaseModel, Field, model_validator
 10
 11from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
 12from kiln_ai.datamodel.dataset_filters import (
 13    DatasetFilter,
 14    DatasetFilterId,
 15    dataset_filter_from_id,
 16)
 17
 18if TYPE_CHECKING:
 19    from kiln_ai.datamodel.task import Task
 20
 21
 22class DatasetSplitDefinition(BaseModel):
 23    """
 24    A definition of a split in a dataset.
 25
 26    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
 27    """
 28
 29    name: str = NAME_FIELD
 30    description: str | None = Field(
 31        default=None,
 32        description="A description of the dataset for you and your team. Not used in training.",
 33    )
 34    percentage: float = Field(
 35        ge=0.0,
 36        le=1.0,
 37        description="The percentage of the dataset that this split represents (between 0 and 1).",
 38    )
 39
 40
 41AllSplitDefinition: list[DatasetSplitDefinition] = [
 42    DatasetSplitDefinition(name="all", percentage=1.0)
 43]
 44Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
 45    DatasetSplitDefinition(name="train", percentage=0.8),
 46    DatasetSplitDefinition(name="test", percentage=0.2),
 47]
 48Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
 49    DatasetSplitDefinition(name="train", percentage=0.6),
 50    DatasetSplitDefinition(name="test", percentage=0.2),
 51    DatasetSplitDefinition(name="val", percentage=0.2),
 52]
 53Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
 54    DatasetSplitDefinition(name="train", percentage=0.8),
 55    DatasetSplitDefinition(name="test", percentage=0.1),
 56    DatasetSplitDefinition(name="val", percentage=0.1),
 57]
 58
 59
 60class DatasetSplit(KilnParentedModel):
 61    """
 62    A collection of task runs, with optional splits (train, test, validation).
 63
 64    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
 65
 66    Maintains a list of IDs for each split, to avoid data duplication.
 67    """
 68
 69    name: str = NAME_FIELD
 70    description: str | None = Field(
 71        default=None,
 72        description="A description of the dataset for you and your team. Not used in training.",
 73    )
 74    splits: list[DatasetSplitDefinition] = Field(
 75        default_factory=list,
 76        description="The splits in the dataset.",
 77    )
 78    split_contents: dict[str, list[str]] = Field(
 79        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
 80    )
 81    filter: DatasetFilterId | None = Field(
 82        default=None,
 83        description="The filter used to build the dataset.",
 84    )
 85
 86    @model_validator(mode="after")
 87    def validate_split_percentages(self) -> "DatasetSplit":
 88        total = sum(split.percentage for split in self.splits)
 89        if not math.isclose(total, 1.0, rel_tol=1e-9):
 90            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
 91        return self
 92
 93    @classmethod
 94    def from_task(
 95        cls,
 96        name: str,
 97        task: "Task",
 98        splits: list[DatasetSplitDefinition],
 99        filter_id: DatasetFilterId = "all",
100        description: str | None = None,
101    ):
102        """
103        Build a dataset split from a task.
104        """
105        filter = dataset_filter_from_id(filter_id)
106        split_contents = cls.build_split_contents(task, splits, filter)
107        return cls(
108            parent=task,
109            name=name,
110            description=description,
111            splits=splits,
112            split_contents=split_contents,
113            filter=filter_id,
114        )
115
116    @classmethod
117    def build_split_contents(
118        cls,
119        task: "Task",
120        splits: list[DatasetSplitDefinition],
121        filter: DatasetFilter,
122    ) -> dict[str, list[str]]:
123        valid_ids = []
124        for task_run in task.runs():
125            if filter(task_run):
126                valid_ids.append(task_run.id)
127
128        # Shuffle and split by split percentage
129        random.shuffle(valid_ids)
130        split_contents = {}
131        start_idx = 0
132        remaining_items = len(valid_ids)
133
134        # Handle all splits except the last one
135        for split in splits[:-1]:
136            split_size = round(len(valid_ids) * split.percentage)
137            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
138            start_idx += split_size
139            remaining_items -= split_size
140
141        # Last split gets all remaining items (for rounding)
142        if splits:
143            split_contents[splits[-1].name] = valid_ids[start_idx:]
144
145        return split_contents
146
147    def parent_task(self) -> "Task | None":
148        # inline import to avoid circular import
149        from kiln_ai.datamodel import Task
150
151        if not isinstance(self.parent, Task):
152            return None
153        return self.parent
154
155    def missing_count(self) -> int:
156        """
157        Returns:
158            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
159        """
160        parent = self.parent_task()
161        if parent is None:
162            raise ValueError("DatasetSplit has no parent task")
163
164        runs = parent.runs(readonly=True)
165        all_ids = set(run.id for run in runs)
166        all_ids_in_splits = set()
167        for ids in self.split_contents.values():
168            all_ids_in_splits.update(ids)
169        missing = all_ids_in_splits - all_ids
170        return len(missing)
class DatasetSplitDefinition(pydantic.main.BaseModel):

A definition of a split in a dataset.

Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)

name: str
description: str | None
percentage: float
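
For example, a custom 70/30 split can be defined directly (the names and description here are illustrative):

custom_splits = [
    DatasetSplitDefinition(name="train", percentage=0.7),
    DatasetSplitDefinition(
        name="test",
        description="Held-out evaluation set",
        percentage=0.3,
    ),
]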

AllSplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='all', description=None, percentage=1.0)]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.2)]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.6), DatasetSplitDefinition(name='test', description=None, percentage=0.2), DatasetSplitDefinition(name='val', description=None, percentage=0.2)]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.1), DatasetSplitDefinition(name='val', description=None, percentage=0.1)]
class DatasetSplit(kiln_ai.datamodel.basemodel.KilnParentedModel):

A collection of task runs, with optional splits (train, test, validation).

Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

Maintains a list of IDs for each split, to avoid data duplication.

name: str
description: str | None
splits: list[DatasetSplitDefinition]
split_contents: dict[str, list[str]]
filter: DatasetFilterId | None
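
split_contents maps each split name to the run IDs assigned to it, so a frozen dataset can be read back without copying run data. A sketch, assuming dataset_split is an existing DatasetSplit:

for split_name, run_ids in dataset_split.split_contents.items():
    print(f"{split_name}: {len(run_ids)} runs")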
@model_validator(mode='after')
def validate_split_percentages(self) -> DatasetSplit:
Validates that the percentages of all splits sum to 1.0 (within a small floating point tolerance); construction fails with a ValueError otherwise.
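
Since the validator runs in mode="after", an invalid percentage total surfaces as a pydantic ValidationError at construction time. A sketch (constructed in memory without a parent, for brevity):

from pydantic import ValidationError

try:
    DatasetSplit(
        name="bad split",  # must satisfy NAME_FIELD constraints
        splits=[DatasetSplitDefinition(name="train", percentage=0.5)],
        split_contents={"train": []},
    )
except ValidationError as e:
    print(e)  # includes "The sum of split percentages must be 1.0 (got 0.5)"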
@classmethod
def from_task(cls, name: str, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter_id: DatasetFilterId = 'all', description: str | None = None):

Build a dataset split from a task.
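
A sketch with the optional parameters spelled out; filter_id defaults to "all", and any other value must be a valid DatasetFilterId resolvable by dataset_filter_from_id (task is assumed to be an existing kiln_ai.datamodel.Task):

split = DatasetSplit.from_task(
    name="frozen v1",
    task=task,
    splits=Train60Test20Val20SplitDefinition,
    filter_id="all",
    description="Snapshot frozen for fine-tuning",
)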

@classmethod
def build_split_contents(cls, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter: kiln_ai.datamodel.dataset_filters.DatasetFilter) -> dict[str, list[str]]:
Collects the IDs of the task's runs that pass the filter, shuffles them, and slices the shuffled list into the requested splits by percentage. The last split receives all remaining items, absorbing any rounding error.
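
Intermediate splits are sized with round(), and the final split takes whatever is left. A standalone sketch of the same arithmetic on plain lists (independent of Kiln):

ids = [str(i) for i in range(10)]
percentages = [0.8, 0.1, 0.1]  # mirrors Train80Test10Val10SplitDefinition

start = 0
sizes = []
for p in percentages[:-1]:
    size = round(len(ids) * p)  # 8, then 1
    sizes.append(size)
    start += size
sizes.append(len(ids) - start)  # last split gets the remaining 1 item
print(sizes)  # [8, 1, 1]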
def parent_task(self) -> kiln_ai.datamodel.Task | None:
Returns the parent Task, or None if no parent Task is set. The Task import is done inline to avoid a circular import.
def missing_count(self) -> int:

Returns the number of task runs that have an ID persisted in this dataset split but no longer exist in the dataset. Raises ValueError if the split has no parent task.
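
A sketch of checking a frozen split for drift after runs have been deleted (dataset_split is assumed to be attached to a parent Task):

missing = dataset_split.missing_count()
if missing:
    print(f"{missing} run IDs in this split no longer exist in the dataset")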

def relationship_name() -> str:

Returns the name of this model's relationship to its parent. Generated by the KilnParentedModel framework.

def parent_type() -> Type[kiln_ai.datamodel.basemodel.KilnParentModel]:

Returns the parent model class for this relationship (Task, for DatasetSplit). Generated by the KilnParentedModel framework.

model_config = {'validate_assignment': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
