kiln_ai.datamodel.dataset_split

Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.

  1"""
  2Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
  3"""
  4
  5import math
  6import random
  7from typing import TYPE_CHECKING
  8
  9from pydantic import BaseModel, Field, model_validator
 10
 11from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
 12from kiln_ai.datamodel.dataset_filters import (
 13    DatasetFilter,
 14    DatasetFilterId,
 15    dataset_filter_from_id,
 16)
 17from kiln_ai.datamodel.task_run import TaskRun
 18
 19if TYPE_CHECKING:
 20    from kiln_ai.datamodel.task import Task
 21
 22
 23class DatasetToolInfo(BaseModel):
 24    """
 25    Information about tools used across task runs in a dataset split.
 26    """
 27
 28    has_tool_mismatch: bool = Field(
 29        description="Whether the tools from each run match across all runs in the dataset split."
 30    )
 31    tools: list[str] = Field(
 32        default_factory=list,
 33        description="Common tool IDs shared by every run; empty when tools are mismatched or no tools exist.",
 34    )
 35
 36
 37class DatasetSplitDefinition(BaseModel):
 38    """
 39    A definition of a split in a dataset.
 40
 41    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
 42    """
 43
 44    name: FilenameString = Field(
 45        description="The name of the dataset split definition."
 46    )
 47    description: str | None = Field(
 48        default=None,
 49        description="A description of the dataset for you and your team. Not used in training.",
 50    )
 51    percentage: float = Field(
 52        ge=0.0,
 53        le=1.0,
 54        description="The percentage of the dataset that this split represents (between 0 and 1).",
 55    )
 56
 57
 58AllSplitDefinition: list[DatasetSplitDefinition] = [
 59    DatasetSplitDefinition(name="all", percentage=1.0)
 60]
 61Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
 62    DatasetSplitDefinition(name="train", percentage=0.8),
 63    DatasetSplitDefinition(name="test", percentage=0.2),
 64]
 65Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
 66    DatasetSplitDefinition(name="train", percentage=0.8),
 67    DatasetSplitDefinition(name="val", percentage=0.2),
 68]
 69Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
 70    DatasetSplitDefinition(name="train", percentage=0.6),
 71    DatasetSplitDefinition(name="test", percentage=0.2),
 72    DatasetSplitDefinition(name="val", percentage=0.2),
 73]
 74Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
 75    DatasetSplitDefinition(name="train", percentage=0.8),
 76    DatasetSplitDefinition(name="test", percentage=0.1),
 77    DatasetSplitDefinition(name="val", percentage=0.1),
 78]
 79
 80
 81class DatasetSplit(KilnParentedModel):
 82    """
 83    A collection of task runs, with optional splits (train, test, validation).
 84
 85    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
 86
 87    Maintains a list of IDs for each split, to avoid data duplication.
 88    """
 89
 90    name: FilenameString = Field(description="The name of the dataset split.")
 91    description: str | None = Field(
 92        default=None,
 93        description="A description of the dataset for you and your team. Not used in training.",
 94    )
 95    splits: list[DatasetSplitDefinition] = Field(
 96        default_factory=list,
 97        description="The splits in the dataset.",
 98    )
 99    split_contents: dict[str, list[str]] = Field(
100        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
101    )
102    filter: DatasetFilterId | None = Field(
103        default=None,
104        description="The filter used to build the dataset.",
105    )
106
107    @model_validator(mode="after")
108    def validate_split_percentages(self) -> "DatasetSplit":
109        total = sum(split.percentage for split in self.splits)
110        if not math.isclose(total, 1.0, rel_tol=1e-9):
111            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
112        return self
113
114    @classmethod
115    def from_task(
116        cls,
117        name: str,
118        task: "Task",
119        splits: list[DatasetSplitDefinition],
120        filter_id: DatasetFilterId = "all",
121        description: str | None = None,
122    ):
123        """
124        Build a dataset split from a task.
125        """
126        filter = dataset_filter_from_id(filter_id)
127        split_contents = cls.build_split_contents(task, splits, filter)
128        return cls(
129            parent=task,
130            name=name,
131            description=description,
132            splits=splits,
133            split_contents=split_contents,
134            filter=filter_id,
135        )
136
137    @classmethod
138    def build_split_contents(
139        cls,
140        task: "Task",
141        splits: list[DatasetSplitDefinition],
142        filter: DatasetFilter,
143    ) -> dict[str, list[str]]:
144        valid_ids = []
145        for task_run in task.runs():
146            if filter(task_run):
147                valid_ids.append(task_run.id)
148
149        # Shuffle and split by split percentage
150        random.shuffle(valid_ids)
151        split_contents = {}
152        start_idx = 0
153        remaining_items = len(valid_ids)
154
155        # Handle all splits except the last one
156        for split in splits[:-1]:
157            split_size = round(len(valid_ids) * split.percentage)
158            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
159            start_idx += split_size
160            remaining_items -= split_size
161
162        # Last split gets all remaining items (for rounding)
163        if splits:
164            split_contents[splits[-1].name] = valid_ids[start_idx:]
165
166        return split_contents
167
168    def parent_task(self) -> "Task | None":
169        # inline import to avoid circular import
170        from kiln_ai.datamodel import Task
171
172        if not isinstance(self.parent, Task):
173            return None
174        return self.parent
175
176    def missing_count(self) -> int:
177        """
178        Returns:
179            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
180        """
181        parent = self.parent_task()
182        if parent is None:
183            raise ValueError("DatasetSplit has no parent task")
184
185        runs = parent.runs(readonly=True)
186        all_ids = set(run.id for run in runs)
187        all_ids_in_splits = set()
188        for ids in self.split_contents.values():
189            all_ids_in_splits.update(ids)
190        missing = all_ids_in_splits - all_ids
191        return len(missing)
192
193    def _get_runs(self) -> list[TaskRun]:
194        """
195        Get all task runs referenced in this dataset split.
196
197        Returns:
198            list[TaskRun]: list of task runs in this dataset split
199        """
200        parent = self.parent_task()
201        if parent is None:
202            return []
203
204        runs = []
205        all_run_ids = set()
206        for run_ids in self.split_contents.values():
207            all_run_ids.update(run_ids)
208
209        # Find all runs by their IDs
210        for task_run in parent.runs(readonly=True):
211            if task_run.id in all_run_ids:
212                runs.append(task_run)
213
214        return runs
215
216    @staticmethod
217    def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo:
218        """
219        Compute tool info from a list of task runs.
220
221        Args:
222            runs: list of task runs to analyze
223
224        Returns:
225            DatasetToolInfo: information about tools used across the task runs
226        """
227
228        has_tool_mismatch = False
229        tools: set[str] | None = None
230
231        for run in runs:
232            # Extract tools from run config, treating missing source/run_config/tools_config as empty tools
233            run_tools: set[str] = set()
234            source = run.output.source if run.output else None
235            if source is not None and source.run_config is not None:
236                tools_config = source.run_config.tools_config
237                if tools_config is not None:
238                    run_tools = set(tools_config.tools)
239
240            # First run establishes the expected tool set (including empty)
241            if tools is None:
242                tools = run_tools
243            elif run_tools != tools:
244                # Mismatch found
245                has_tool_mismatch = True
246                tools = set()
247                break
248
249        # If no valid runs were processed, return empty tools
250        if tools is None:
251            tools = set()
252
253        return DatasetToolInfo(has_tool_mismatch=has_tool_mismatch, tools=sorted(tools))
254
255    def tool_info(self) -> DatasetToolInfo:
256        """
257        Helper method to compute tool info for the dataset split. Iterates through all runs in the dataset split and checks the tools used in each run config.
258
259        Returns:
260            DatasetToolInfo: information about tools used across task runs in this dataset split
261        """
262        runs = self._get_runs()
263        tool_info = self.compute_tool_info(runs)
264        return tool_info
class DatasetToolInfo(pydantic.main.BaseModel):
24class DatasetToolInfo(BaseModel):
25    """
26    Information about tools used across task runs in a dataset split.
27    """
28
29    has_tool_mismatch: bool = Field(
30        description="Whether the tools from each run match across all runs in the dataset split."
31    )
32    tools: list[str] = Field(
33        default_factory=list,
34        description="Common tool IDs shared by every run; empty when tools are mismatched or no tools exist.",
35    )

Information about tools used across task runs in a dataset split.

has_tool_mismatch: bool
tools: list[str]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
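
A quick sketch of how the two fields relate, using hypothetical values rather than a real dataset:

    from kiln_ai.datamodel.dataset_split import DatasetToolInfo

    # Every run used the same two tools: no mismatch, and `tools` lists the shared IDs.
    info = DatasetToolInfo(has_tool_mismatch=False, tools=["calculator", "search"])

    # Runs disagreed on tools: the mismatch is flagged and `tools` defaults to [].
    mismatched = DatasetToolInfo(has_tool_mismatch=True)
    assert mismatched.tools == []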

class DatasetSplitDefinition(pydantic.main.BaseModel):
38class DatasetSplitDefinition(BaseModel):
39    """
40    A definition of a split in a dataset.
41
42    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
43    """
44
45    name: FilenameString = Field(
46        description="The name of the dataset split definition."
47    )
48    description: str | None = Field(
49        default=None,
50        description="A description of the dataset for you and your team. Not used in training.",
51    )
52    percentage: float = Field(
53        ge=0.0,
54        le=1.0,
55        description="The percentage of the dataset that this split represents (between 0 and 1).",
56    )

A definition of a split in a dataset.

Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)

name: FilenameString
description: str | None
percentage: float
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
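
A minimal construction sketch, mirroring the example in the docstring above:

    from kiln_ai.datamodel.dataset_split import DatasetSplitDefinition

    train = DatasetSplitDefinition(
        name="train",
        description="The training set",  # optional; not used in training
        percentage=0.8,                  # must be between 0 and 1
    )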

AllSplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='all', description=None, percentage=1.0)]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.2)]
Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='val', description=None, percentage=0.2)]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.6), DatasetSplitDefinition(name='test', description=None, percentage=0.2), DatasetSplitDefinition(name='val', description=None, percentage=0.2)]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.1), DatasetSplitDefinition(name='val', description=None, percentage=0.1)]
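
If none of the predefined lists fit, a custom list works the same way. A sketch of a hypothetical 70/15/15 split; DatasetSplit validates that the percentages sum to 1.0:

    from kiln_ai.datamodel.dataset_split import DatasetSplitDefinition

    Train70Test15Val15SplitDefinition = [
        DatasetSplitDefinition(name="train", percentage=0.70),
        DatasetSplitDefinition(name="test", percentage=0.15),
        DatasetSplitDefinition(name="val", percentage=0.15),
    ]
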
class DatasetSplit(kiln_ai.datamodel.basemodel.KilnParentedModel):
 82class DatasetSplit(KilnParentedModel):
 83    """
 84    A collection of task runs, with optional splits (train, test, validation).
 85
 86    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
 87
 88    Maintains a list of IDs for each split, to avoid data duplication.
 89    """
 90
 91    name: FilenameString = Field(description="The name of the dataset split.")
 92    description: str | None = Field(
 93        default=None,
 94        description="A description of the dataset for you and your team. Not used in training.",
 95    )
 96    splits: list[DatasetSplitDefinition] = Field(
 97        default_factory=list,
 98        description="The splits in the dataset.",
 99    )
100    split_contents: dict[str, list[str]] = Field(
101        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
102    )
103    filter: DatasetFilterId | None = Field(
104        default=None,
105        description="The filter used to build the dataset.",
106    )
107
108    @model_validator(mode="after")
109    def validate_split_percentages(self) -> "DatasetSplit":
110        total = sum(split.percentage for split in self.splits)
111        if not math.isclose(total, 1.0, rel_tol=1e-9):
112            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
113        return self
114
115    @classmethod
116    def from_task(
117        cls,
118        name: str,
119        task: "Task",
120        splits: list[DatasetSplitDefinition],
121        filter_id: DatasetFilterId = "all",
122        description: str | None = None,
123    ):
124        """
125        Build a dataset split from a task.
126        """
127        filter = dataset_filter_from_id(filter_id)
128        split_contents = cls.build_split_contents(task, splits, filter)
129        return cls(
130            parent=task,
131            name=name,
132            description=description,
133            splits=splits,
134            split_contents=split_contents,
135            filter=filter_id,
136        )
137
138    @classmethod
139    def build_split_contents(
140        cls,
141        task: "Task",
142        splits: list[DatasetSplitDefinition],
143        filter: DatasetFilter,
144    ) -> dict[str, list[str]]:
145        valid_ids = []
146        for task_run in task.runs():
147            if filter(task_run):
148                valid_ids.append(task_run.id)
149
150        # Shuffle and split by split percentage
151        random.shuffle(valid_ids)
152        split_contents = {}
153        start_idx = 0
154        remaining_items = len(valid_ids)
155
156        # Handle all splits except the last one
157        for split in splits[:-1]:
158            split_size = round(len(valid_ids) * split.percentage)
159            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
160            start_idx += split_size
161            remaining_items -= split_size
162
163        # Last split gets all remaining items (for rounding)
164        if splits:
165            split_contents[splits[-1].name] = valid_ids[start_idx:]
166
167        return split_contents
168
169    def parent_task(self) -> "Task | None":
170        # inline import to avoid circular import
171        from kiln_ai.datamodel import Task
172
173        if not isinstance(self.parent, Task):
174            return None
175        return self.parent
176
177    def missing_count(self) -> int:
178        """
179        Returns:
180            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
181        """
182        parent = self.parent_task()
183        if parent is None:
184            raise ValueError("DatasetSplit has no parent task")
185
186        runs = parent.runs(readonly=True)
187        all_ids = set(run.id for run in runs)
188        all_ids_in_splits = set()
189        for ids in self.split_contents.values():
190            all_ids_in_splits.update(ids)
191        missing = all_ids_in_splits - all_ids
192        return len(missing)
193
194    def _get_runs(self) -> list[TaskRun]:
195        """
196        Get all task runs referenced in this dataset split.
197
198        Returns:
199            list[TaskRun]: list of task runs in this dataset split
200        """
201        parent = self.parent_task()
202        if parent is None:
203            return []
204
205        runs = []
206        all_run_ids = set()
207        for run_ids in self.split_contents.values():
208            all_run_ids.update(run_ids)
209
210        # Find all runs by their IDs
211        for task_run in parent.runs(readonly=True):
212            if task_run.id in all_run_ids:
213                runs.append(task_run)
214
215        return runs
216
217    @staticmethod
218    def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo:
219        """
220        Compute tool info from a list of task runs.
221
222        Args:
223            runs: list of task runs to analyze
224
225        Returns:
226            DatasetToolInfo: information about tools used across the task runs
227        """
228
229        has_tool_mismatch = False
230        tools: set[str] | None = None
231
232        for run in runs:
233            # Extract tools from run config, treating missing source/run_config/tools_config as empty tools
234            run_tools: set[str] = set()
235            source = run.output.source if run.output else None
236            if source is not None and source.run_config is not None:
237                tools_config = source.run_config.tools_config
238                if tools_config is not None:
239                    run_tools = set(tools_config.tools)
240
241            # First run establishes the expected tool set (including empty)
242            if tools is None:
243                tools = run_tools
244            elif run_tools != tools:
245                # Mismatch found
246                has_tool_mismatch = True
247                tools = set()
248                break
249
250        # If no valid runs were processed, return empty tools
251        if tools is None:
252            tools = set()
253
254        return DatasetToolInfo(has_tool_mismatch=has_tool_mismatch, tools=sorted(tools))
255
256    def tool_info(self) -> DatasetToolInfo:
257        """
258        Helper method to compute tool info for the dataset split. Iterates through all runs in the dataset split and checks the tools used in each run config.
259
260        Returns:
261            DatasetToolInfo: information about tools used across task runs in this dataset split
262        """
263        runs = self._get_runs()
264        tool_info = self.compute_tool_info(runs)
265        return tool_info

A collection of task runs, with optional splits (train, test, validation).

Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

Maintains a list of IDs for each split, to avoid data duplication.

name: FilenameString
description: str | None
splits: list[DatasetSplitDefinition]
split_contents: dict[str, list[str]]
filter: DatasetFilterId | None
@model_validator(mode='after')
def validate_split_percentages(self) -> DatasetSplit:
108    @model_validator(mode="after")
109    def validate_split_percentages(self) -> "DatasetSplit":
110        total = sum(split.percentage for split in self.splits)
111        if not math.isclose(total, 1.0, rel_tol=1e-9):
112            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
113        return self
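
Because this is an after-model validator, a mis-specified split list fails at construction time. A minimal sketch of the failure mode, assuming a DatasetSplit can be built directly (without a parent task) for illustration:

    from pydantic import ValidationError

    from kiln_ai.datamodel.dataset_split import DatasetSplit, DatasetSplitDefinition

    try:
        DatasetSplit(
            name="bad_split",
            splits=[
                DatasetSplitDefinition(name="train", percentage=0.7),
                DatasetSplitDefinition(name="test", percentage=0.2),  # sums to 0.9
            ],
            split_contents={"train": [], "test": []},
        )
    except ValidationError as e:
        print(e)  # reports that the split percentages must sum to 1.0
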
@classmethod
def from_task(cls, name: str, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter_id: DatasetFilterId = 'all', description: str | None = None):
115    @classmethod
116    def from_task(
117        cls,
118        name: str,
119        task: "Task",
120        splits: list[DatasetSplitDefinition],
121        filter_id: DatasetFilterId = "all",
122        description: str | None = None,
123    ):
124        """
125        Build a dataset split from a task.
126        """
127        filter = dataset_filter_from_id(filter_id)
128        split_contents = cls.build_split_contents(task, splits, filter)
129        return cls(
130            parent=task,
131            name=name,
132            description=description,
133            splits=splits,
134            split_contents=split_contents,
135            filter=filter_id,
136        )

Build a dataset split from a task.
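
A usage sketch, assuming `task` is an existing Task with saved runs (the name "frozen_v1" is hypothetical):

    from kiln_ai.datamodel.dataset_split import (
        DatasetSplit,
        Train80Test20SplitDefinition,
    )

    split = DatasetSplit.from_task(
        name="frozen_v1",
        task=task,
        splits=Train80Test20SplitDefinition,
        filter_id="all",  # the default: include every run
    )
    print({name: len(ids) for name, ids in split.split_contents.items()})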

@classmethod
def build_split_contents(cls, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter: kiln_ai.datamodel.dataset_filters.DatasetFilter) -> dict[str, list[str]]:
138    @classmethod
139    def build_split_contents(
140        cls,
141        task: "Task",
142        splits: list[DatasetSplitDefinition],
143        filter: DatasetFilter,
144    ) -> dict[str, list[str]]:
145        valid_ids = []
146        for task_run in task.runs():
147            if filter(task_run):
148                valid_ids.append(task_run.id)
149
150        # Shuffle and split by split percentage
151        random.shuffle(valid_ids)
152        split_contents = {}
153        start_idx = 0
154        remaining_items = len(valid_ids)
155
156        # Handle all splits except the last one
157        for split in splits[:-1]:
158            split_size = round(len(valid_ids) * split.percentage)
159            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
160            start_idx += split_size
161            remaining_items -= split_size
162
163        # Last split gets all remaining items (for rounding)
164        if splits:
165            split_contents[splits[-1].name] = valid_ids[start_idx:]
166
167        return split_contents
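
The rounding behavior is worth spelling out: every split except the last receives round(n * percentage) items, and the last split absorbs whatever remains. A self-contained sketch of the same arithmetic, outside the Kiln datamodel:

    # 11 runs split 80/20: round(11 * 0.8) == 9 go to train, the remaining 2 to test.
    ids = [str(i) for i in range(11)]
    train_size = round(len(ids) * 0.8)
    train, test = ids[:train_size], ids[train_size:]
    assert (len(train), len(test)) == (9, 2)
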
def parent_task(self) -> kiln_ai.datamodel.Task | None:
169    def parent_task(self) -> "Task | None":
170        # inline import to avoid circular import
171        from kiln_ai.datamodel import Task
172
173        if not isinstance(self.parent, Task):
174            return None
175        return self.parent
def missing_count(self) -> int:
177    def missing_count(self) -> int:
178        """
179        Returns:
180            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
181        """
182        parent = self.parent_task()
183        if parent is None:
184            raise ValueError("DatasetSplit has no parent task")
185
186        runs = parent.runs(readonly=True)
187        all_ids = set(run.id for run in runs)
188        all_ids_in_splits = set()
189        for ids in self.split_contents.values():
190            all_ids_in_splits.update(ids)
191        missing = all_ids_in_splits - all_ids
192        return len(missing)

Returns: int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
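
A hypothetical guard before exporting a frozen split, assuming `split` is a DatasetSplit attached to a parent task:

    missing = split.missing_count()
    if missing > 0:
        # Some runs were deleted from the task after this split was frozen.
        print(f"Warning: {missing} run IDs in this split no longer exist")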

@staticmethod
def compute_tool_info(runs: list[kiln_ai.datamodel.TaskRun]) -> DatasetToolInfo:
217    @staticmethod
218    def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo:
219        """
220        Compute tool info from a list of task runs.
221
222        Args:
223            runs: list of task runs to analyze
224
225        Returns:
226            DatasetToolInfo: information about tools used across the task runs
227        """
228
229        has_tool_mismatch = False
230        tools: set[str] | None = None
231
232        for run in runs:
233            # Extract tools from run config, treating missing source/run_config/tools_config as empty tools
234            run_tools: set[str] = set()
235            source = run.output.source if run.output else None
236            if source is not None and source.run_config is not None:
237                tools_config = source.run_config.tools_config
238                if tools_config is not None:
239                    run_tools = set(tools_config.tools)
240
241            # First run establishes the expected tool set (including empty)
242            if tools is None:
243                tools = run_tools
244            elif run_tools != tools:
245                # Mismatch found
246                has_tool_mismatch = True
247                tools = set()
248                break
249
250        # If no valid runs were processed, return empty tools
251        if tools is None:
252            tools = set()
253
254        return DatasetToolInfo(has_tool_mismatch=has_tool_mismatch, tools=sorted(tools))

Compute tool info from a list of task runs.

Args: runs: list of task runs to analyze

Returns: DatasetToolInfo: information about tools used across the task runs
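
A standalone sketch of the same common-tool-set logic, operating on plain sets of tool IDs instead of TaskRun objects (hypothetical data, not the Kiln API):

    def common_tools(per_run_tools: list[set[str]]) -> tuple[bool, list[str]]:
        expected: set[str] | None = None
        for run_tools in per_run_tools:
            if expected is None:
                expected = run_tools  # the first run establishes the expected set
            elif run_tools != expected:
                return True, []  # mismatch: flag it and report no common tools
        return False, sorted(expected or set())

    assert common_tools([{"search", "calc"}, {"search", "calc"}]) == (False, ["calc", "search"])
    assert common_tools([{"search"}, {"calc"}]) == (True, [])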

def tool_info(self) -> DatasetToolInfo:
256    def tool_info(self) -> DatasetToolInfo:
257        """
258        Helper method to compute tool info for the dataset split. Iterates through all runs in the dataset split and checks the tools used in each run config.
259
260        Returns:
261            DatasetToolInfo: information about tools used across task runs in this dataset split
262        """
263        runs = self._get_runs()
264        tool_info = self.compute_tool_info(runs)
265        return tool_info

Helper method to compute tool info for the dataset split. Iterates through all runs in the dataset split and checks the tools used in each run config.

Returns: DatasetToolInfo: information about tools used across task runs in this dataset split
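
A usage sketch, assuming `split` is a DatasetSplit attached to a parent Task (without a parent, _get_runs returns an empty list and the result reports no tools):

    info = split.tool_info()
    if info.has_tool_mismatch:
        print("Runs in this split were generated with different tool sets")
    else:
        print("Tools shared by every run:", info.tools)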

def relationship_name() -> str:

The name of the relationship this model has to its parent task. Inherited from KilnParentedModel.

def parent_type() -> Type[kiln_ai.datamodel.basemodel.KilnParentModel]:

The parent model type for this relationship. Inherited from KilnParentedModel.

model_config = {'validate_assignment': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:

Pydantic hook that initialises private attributes after model construction. Inherited from pydantic's BaseModel.