kiln_ai.datamodel.dataset_split
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
"""
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
"""

import math
import random
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field, model_validator

from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
from kiln_ai.datamodel.dataset_filters import (
    DatasetFilter,
    DatasetFilterId,
    dataset_filter_from_id,
)

if TYPE_CHECKING:
    from kiln_ai.datamodel.task import Task


class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    name: FilenameString = Field(
        description="The name of the dataset split definition."
    )
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )


AllSplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="all", percentage=1.0)
]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.2),
]
Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="val", percentage=0.2),
]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.6),
    DatasetSplitDefinition(name="test", percentage=0.2),
    DatasetSplitDefinition(name="val", percentage=0.2),
]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.1),
    DatasetSplitDefinition(name="val", percentage=0.1),
]


class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: FilenameString = Field(description="The name of the dataset split.")
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )
    filter: DatasetFilterId | None = Field(
        default=None,
        description="The filter used to build the dataset.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.
        """
        filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        valid_ids = []
        for task_run in task.runs():
            if filter(task_run):
                valid_ids.append(task_run.id)

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents = {}
        start_idx = 0
        remaining_items = len(valid_ids)

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size
            remaining_items -= split_size

        # Last split gets all remaining items (for rounding)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = set(run.id for run in runs)
        all_ids_in_splits = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)
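For orientation, a minimal usage sketch. It assumes an existing Task instance named task with saved runs; everything else comes from this module, and the dataset name is illustrative.

from kiln_ai.datamodel.dataset_split import (
    DatasetSplit,
    Train80Test20SplitDefinition,
)

# `task` is assumed to be an existing kiln_ai Task with saved runs.
split = DatasetSplit.from_task(
    name="frozen_dataset_v1",  # illustrative name
    task=task,
    splits=Train80Test20SplitDefinition,
    filter_id="all",  # include every run that passes the "all" filter
)

# split_contents maps each split name to a list of task run IDs.
print({name: len(ids) for name, ids in split.split_contents.items()})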
class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    name: FilenameString = Field(
        description="The name of the dataset split definition."
    )
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )
A definition of a split in a dataset.
Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
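Beyond the predefined constants in the module source above (Train80Test20SplitDefinition and friends), a custom scheme is just a list of these definitions. A sketch of a hypothetical 70/15/15 layout:

custom_splits = [
    DatasetSplitDefinition(name="train", percentage=0.7),
    DatasetSplitDefinition(name="test", percentage=0.15),
    DatasetSplitDefinition(name="val", percentage=0.15),
]
# The percentages must sum to 1.0: DatasetSplit.validate_split_percentages
# rejects any other total when the split is built.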
class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: FilenameString = Field(description="The name of the dataset split.")
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )
    filter: DatasetFilterId | None = Field(
        default=None,
        description="The filter used to build the dataset.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.
        """
        filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        valid_ids = []
        for task_run in task.runs():
            if filter(task_run):
                valid_ids.append(task_run.id)

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents = {}
        start_idx = 0
        remaining_items = len(valid_ids)

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size
            remaining_items -= split_size

        # Last split gets all remaining items (for rounding)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = set(run.id for run in runs)
        all_ids_in_splits = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)
A collection of task runs, with optional splits (train, test, validation).
Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
Maintains a list of IDs for each split, to avoid data duplication.
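The percentage validator is enforced at construction time. A sketch of what a bad total looks like, assuming a DatasetSplit can be constructed directly without a parent (normally you would use from_task below):

from pydantic import ValidationError

try:
    DatasetSplit(
        name="bad_split",
        splits=[
            DatasetSplitDefinition(name="train", percentage=0.5),
            DatasetSplitDefinition(name="test", percentage=0.4),  # totals 0.9
        ],
        split_contents={"train": [], "test": []},
    )
except ValidationError as err:
    print(err)  # wraps: "The sum of split percentages must be 1.0 (got 0.9)"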
    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.
        """
        filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )
Build a dataset split from a task.
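from_task is a thin convenience wrapper: it resolves the filter id, delegates to build_split_contents, and attaches the result to the task. Roughly equivalent manual steps, as a sketch (`task` assumed to exist):

filter_fn = dataset_filter_from_id("all")
contents = DatasetSplit.build_split_contents(
    task, Train80Test20SplitDefinition, filter_fn
)
manual = DatasetSplit(
    parent=task,
    name="manual_build",  # illustrative name
    splits=Train80Test20SplitDefinition,
    split_contents=contents,
    filter="all",
)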
    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        valid_ids = []
        for task_run in task.runs():
            if filter(task_run):
                valid_ids.append(task_run.id)

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents = {}
        start_idx = 0
        remaining_items = len(valid_ids)

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size
            remaining_items -= split_size

        # Last split gets all remaining items (for rounding)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents
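The sizing rule is worth spelling out: every split except the last gets round(n * percentage) items, and the last split absorbs whatever remains, so rounding never drops or duplicates an ID. A standalone illustration (not library code) with 11 items and an 80/10/10 scheme:

ids = [f"run_{i}" for i in range(11)]
# round(11 * 0.8) = 9 for train, round(11 * 0.1) = 1 for test,
# and the last split takes the 1 leftover item.
train, test, val = ids[:9], ids[9:10], ids[10:]
assert len(train) + len(test) + len(val) == len(ids)  # nothing lost to rounding

Note that random.shuffle uses the module-level random state, so the assignment of runs to splits is not reproducible across calls unless the global random module is seeded beforehand.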
    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = set(run.id for run in runs)
        all_ids_in_splits = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)
Returns the number of task runs that have an ID persisted in this dataset split but no longer exist in the dataset.
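A sketch of the intended use, assuming `split` is a previously built DatasetSplit whose parent task may have had runs deleted since the split was frozen:

stale = split.missing_count()
if stale > 0:
    print(f"{stale} run IDs in this split no longer exist in the dataset")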