kiln_ai.datamodel.dataset_split
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
1""" 2Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split. 3""" 4 5import math 6import random 7from typing import TYPE_CHECKING 8 9from pydantic import BaseModel, Field, model_validator 10 11from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel 12from kiln_ai.datamodel.dataset_filters import ( 13 DatasetFilter, 14 DatasetFilterId, 15 dataset_filter_from_id, 16) 17 18if TYPE_CHECKING: 19 from kiln_ai.datamodel.task import Task 20 21 22class DatasetSplitDefinition(BaseModel): 23 """ 24 A definition of a split in a dataset. 25 26 Example: name="train", description="The training set", percentage=0.8 (80% of the dataset) 27 """ 28 29 name: str = NAME_FIELD 30 description: str | None = Field( 31 default=None, 32 description="A description of the dataset for you and your team. Not used in training.", 33 ) 34 percentage: float = Field( 35 ge=0.0, 36 le=1.0, 37 description="The percentage of the dataset that this split represents (between 0 and 1).", 38 ) 39 40 41AllSplitDefinition: list[DatasetSplitDefinition] = [ 42 DatasetSplitDefinition(name="all", percentage=1.0) 43] 44Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [ 45 DatasetSplitDefinition(name="train", percentage=0.8), 46 DatasetSplitDefinition(name="test", percentage=0.2), 47] 48Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [ 49 DatasetSplitDefinition(name="train", percentage=0.6), 50 DatasetSplitDefinition(name="test", percentage=0.2), 51 DatasetSplitDefinition(name="val", percentage=0.2), 52] 53Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [ 54 DatasetSplitDefinition(name="train", percentage=0.8), 55 DatasetSplitDefinition(name="test", percentage=0.1), 56 DatasetSplitDefinition(name="val", percentage=0.1), 57] 58 59 60class DatasetSplit(KilnParentedModel): 61 """ 62 A collection of task runs, with optional splits (train, test, validation). 63 64 Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks. 65 66 Maintains a list of IDs for each split, to avoid data duplication. 67 """ 68 69 name: str = NAME_FIELD 70 description: str | None = Field( 71 default=None, 72 description="A description of the dataset for you and your team. Not used in training.", 73 ) 74 splits: list[DatasetSplitDefinition] = Field( 75 default_factory=list, 76 description="The splits in the dataset.", 77 ) 78 split_contents: dict[str, list[str]] = Field( 79 description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.", 80 ) 81 filter: DatasetFilterId | None = Field( 82 default=None, 83 description="The filter used to build the dataset.", 84 ) 85 86 @model_validator(mode="after") 87 def validate_split_percentages(self) -> "DatasetSplit": 88 total = sum(split.percentage for split in self.splits) 89 if not math.isclose(total, 1.0, rel_tol=1e-9): 90 raise ValueError(f"The sum of split percentages must be 1.0 (got {total})") 91 return self 92 93 @classmethod 94 def from_task( 95 cls, 96 name: str, 97 task: "Task", 98 splits: list[DatasetSplitDefinition], 99 filter_id: DatasetFilterId = "all", 100 description: str | None = None, 101 ): 102 """ 103 Build a dataset split from a task. 
104 """ 105 filter = dataset_filter_from_id(filter_id) 106 split_contents = cls.build_split_contents(task, splits, filter) 107 return cls( 108 parent=task, 109 name=name, 110 description=description, 111 splits=splits, 112 split_contents=split_contents, 113 filter=filter_id, 114 ) 115 116 @classmethod 117 def build_split_contents( 118 cls, 119 task: "Task", 120 splits: list[DatasetSplitDefinition], 121 filter: DatasetFilter, 122 ) -> dict[str, list[str]]: 123 valid_ids = [] 124 for task_run in task.runs(): 125 if filter(task_run): 126 valid_ids.append(task_run.id) 127 128 # Shuffle and split by split percentage 129 random.shuffle(valid_ids) 130 split_contents = {} 131 start_idx = 0 132 remaining_items = len(valid_ids) 133 134 # Handle all splits except the last one 135 for split in splits[:-1]: 136 split_size = round(len(valid_ids) * split.percentage) 137 split_contents[split.name] = valid_ids[start_idx : start_idx + split_size] 138 start_idx += split_size 139 remaining_items -= split_size 140 141 # Last split gets all remaining items (for rounding) 142 if splits: 143 split_contents[splits[-1].name] = valid_ids[start_idx:] 144 145 return split_contents 146 147 def parent_task(self) -> "Task | None": 148 # inline import to avoid circular import 149 from kiln_ai.datamodel import Task 150 151 if not isinstance(self.parent, Task): 152 return None 153 return self.parent 154 155 def missing_count(self) -> int: 156 """ 157 Returns: 158 int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset 159 """ 160 parent = self.parent_task() 161 if parent is None: 162 raise ValueError("DatasetSplit has no parent task") 163 164 runs = parent.runs(readonly=True) 165 all_ids = set(run.id for run in runs) 166 all_ids_in_splits = set() 167 for ids in self.split_contents.values(): 168 all_ids_in_splits.update(ids) 169 missing = all_ids_in_splits - all_ids 170 return len(missing)
class DatasetSplitDefinition(BaseModel)
A definition of a split in a dataset.
Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
class DatasetSplit(KilnParentedModel)
A collection of task runs, with optional splits (train, test, validation).
Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
Maintains a list of IDs for each split, to avoid data duplication.
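The validate_split_percentages validator rejects any DatasetSplit whose split percentages do not sum to 1.0 (math.isclose guards against floating-point rounding in the sum). A sketch of the failure mode, assuming construction without a parent task is permitted (parent_task above tolerates a missing parent):

```python
from pydantic import ValidationError

from kiln_ai.datamodel.dataset_split import DatasetSplit, DatasetSplitDefinition

# 0.9 + 0.2 = 1.1, so validate_split_percentages raises; pydantic surfaces
# the validator's ValueError as a ValidationError.
try:
    DatasetSplit(
        name="badsplit",
        splits=[
            DatasetSplitDefinition(name="train", percentage=0.9),
            DatasetSplitDefinition(name="test", percentage=0.2),
        ],
        split_contents={"train": [], "test": []},
    )
except ValidationError as err:
    print(err)  # "The sum of split percentages must be 1.0 ..."
```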
from_task(name: str, task: Task, splits: list[DatasetSplitDefinition], filter_id: DatasetFilterId = "all", description: str | None = None) (classmethod)
Build a dataset split from a task: filters the task's runs with the given dataset filter, then shuffles the matching run IDs and partitions them according to the split definitions.
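A usage sketch, assuming `task` is an existing saved Task with runs attached; "all" is the default filter id, and any other DatasetFilterId from kiln_ai.datamodel.dataset_filters can be passed to restrict which runs are frozen into the split:

```python
from kiln_ai.datamodel.dataset_split import (
    DatasetSplit,
    Train80Test20SplitDefinition,
)

# `task` is assumed to be an existing kiln_ai Task with saved runs.
dataset = DatasetSplit.from_task(
    name="finetune1",
    task=task,
    splits=Train80Test20SplitDefinition,
    filter_id="all",  # the default: include every run
)

# split_contents maps split name -> frozen list of task run IDs
print(len(dataset.split_contents["train"]), len(dataset.split_contents["test"]))
```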
build_split_contents(task: Task, splits: list[DatasetSplitDefinition], filter: DatasetFilter) -> dict[str, list[str]] (classmethod)
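Every split except the last receives round(n * percentage) of the shuffled IDs, and the last split absorbs the remainder, so rounding never drops or duplicates an ID. A worked sketch of the same arithmetic on ten fake IDs (shuffling omitted for readability):

```python
ids = [f"run{i}" for i in range(10)]
percentages = {"train": 0.8, "test": 0.1, "val": 0.1}

contents: dict[str, list[str]] = {}
start = 0
names = list(percentages)
for name in names[:-1]:  # every split except the last
    size = round(len(ids) * percentages[name])
    contents[name] = ids[start : start + size]
    start += size
contents[names[-1]] = ids[start:]  # last split takes whatever remains

print({name: len(runs) for name, runs in contents.items()})
# {'train': 8, 'test': 1, 'val': 1}
```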
missing_count() -> int
Returns:
    int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset.
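Runs can be deleted from a task after a split is frozen, leaving stale IDs behind; missing_count surfaces that drift. A short sketch, assuming `dataset` was built by from_task and still has its parent task:

```python
# A non-zero count means the frozen split references runs that no longer
# exist; rebuild the split before using it for fine-tuning.
if dataset.missing_count() > 0:
    print("Dataset split references deleted runs; rebuild it.")
```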