kiln_ai.datamodel.dataset_split
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
"""
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
"""

import math
import random
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field, model_validator

from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
from kiln_ai.datamodel.dataset_filters import (
    DatasetFilter,
    DatasetFilterId,
    dataset_filter_from_id,
)

if TYPE_CHECKING:
    from kiln_ai.datamodel.task import Task


class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )


# Common split presets, ready to pass to DatasetSplit.from_task.
AllSplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="all", percentage=1.0)
]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.2),
]
Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="val", percentage=0.2),
]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.6),
    DatasetSplitDefinition(name="test", percentage=0.2),
    DatasetSplitDefinition(name="val", percentage=0.2),
]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.1),
    DatasetSplitDefinition(name="val", percentage=0.1),
]


class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )
    filter: DatasetFilterId | None = Field(
        default=None,
        description="The filter used to build the dataset.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        """Validate that the split percentages sum to 1.0 (within float tolerance)."""
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ) -> "DatasetSplit":
        """
        Build a dataset split from a task.

        Args:
            name: Name for the new dataset split.
            task: The task whose runs are snapshotted into the split.
            splits: Split definitions; percentages must sum to 1.0.
            filter_id: ID of the filter selecting which runs to include.
            description: Optional human-readable description.

        Returns:
            A new DatasetSplit parented to the task.
        """
        # Renamed from `filter` to avoid shadowing the builtin.
        run_filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, run_filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        """
        Assign the task's filtered run IDs to splits, randomly shuffled.

        All splits except the last are sized by rounding; the last split
        receives every remaining ID so rounding never drops a run.
        """
        valid_ids = [task_run.id for task_run in task.runs() if filter(task_run)]

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents: dict[str, list[str]] = {}
        start_idx = 0

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # Last split gets all remaining items (absorbs rounding error)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        """Return the parent Task, or None if the parent is missing or not a Task."""
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset

        Raises:
            ValueError: if this split has no parent task.
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = {run.id for run in runs}
        all_ids_in_splits: set[str] = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        return len(all_ids_in_splits - all_ids)
class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    # Split name (e.g. "train", "test", "val"); validated by the shared NAME_FIELD rules.
    name: str = NAME_FIELD
    # Optional team-facing description; not used during training.
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    # Fraction of the dataset assigned to this split, constrained to [0.0, 1.0].
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )
A definition of a split in a dataset.
Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )
    filter: DatasetFilterId | None = Field(
        default=None,
        description="The filter used to build the dataset.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        """Reject splits whose percentages do not sum to 1.0 (within float tolerance)."""
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ) -> "DatasetSplit":
        """
        Build a dataset split from a task.

        Resolves filter_id to a filter, snapshots the task's matching run IDs
        into split_contents, and returns the new split parented to the task.
        """
        # Renamed from `filter` to avoid shadowing the builtin.
        run_filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, run_filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        """
        Shuffle the task's filtered run IDs and partition them per the split
        percentages. The last split absorbs rounding remainders so no run is
        ever dropped.
        """
        valid_ids = [task_run.id for task_run in task.runs() if filter(task_run)]

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents: dict[str, list[str]] = {}
        start_idx = 0

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # Last split gets all remaining items (for rounding)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        """Return the parent Task, or None when the parent is absent or of another type."""
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset

        Raises:
            ValueError: if this split has no parent task.
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = {run.id for run in runs}
        all_ids_in_splits: set[str] = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        return len(all_ids_in_splits - all_ids)
A collection of task runs, with optional splits (train, test, validation).
Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
Maintains a list of IDs for each split, to avoid data duplication.
98 @classmethod 99 def from_task( 100 cls, 101 name: str, 102 task: "Task", 103 splits: list[DatasetSplitDefinition], 104 filter_id: DatasetFilterId = "all", 105 description: str | None = None, 106 ): 107 """ 108 Build a dataset split from a task. 109 """ 110 filter = dataset_filter_from_id(filter_id) 111 split_contents = cls.build_split_contents(task, splits, filter) 112 return cls( 113 parent=task, 114 name=name, 115 description=description, 116 splits=splits, 117 split_contents=split_contents, 118 filter=filter_id, 119 )
Build a dataset split from a task.
121 @classmethod 122 def build_split_contents( 123 cls, 124 task: "Task", 125 splits: list[DatasetSplitDefinition], 126 filter: DatasetFilter, 127 ) -> dict[str, list[str]]: 128 valid_ids = [] 129 for task_run in task.runs(): 130 if filter(task_run): 131 valid_ids.append(task_run.id) 132 133 # Shuffle and split by split percentage 134 random.shuffle(valid_ids) 135 split_contents = {} 136 start_idx = 0 137 remaining_items = len(valid_ids) 138 139 # Handle all splits except the last one 140 for split in splits[:-1]: 141 split_size = round(len(valid_ids) * split.percentage) 142 split_contents[split.name] = valid_ids[start_idx : start_idx + split_size] 143 start_idx += split_size 144 remaining_items -= split_size 145 146 # Last split gets all remaining items (for rounding) 147 if splits: 148 split_contents[splits[-1].name] = valid_ids[start_idx:] 149 150 return split_contents
160 def missing_count(self) -> int: 161 """ 162 Returns: 163 int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset 164 """ 165 parent = self.parent_task() 166 if parent is None: 167 raise ValueError("DatasetSplit has no parent task") 168 169 runs = parent.runs(readonly=True) 170 all_ids = set(run.id for run in runs) 171 all_ids_in_splits = set() 172 for ids in self.split_contents.values(): 173 all_ids_in_splits.update(ids) 174 missing = all_ids_in_splits - all_ids 175 return len(missing)
Returns: int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
The type of the None singleton.
Configuration for the model, which should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """We need to both initialize private attributes and call the user-defined model_post_init
    method.
    """
    # Order matters: private attributes must exist before the user hook runs.
    init_private_attributes(self, context)
    original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.