kiln_ai.datamodel.dataset_split
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
"""
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
"""

import math
import random
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field, model_validator

from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
from kiln_ai.datamodel.dataset_filters import (
    DatasetFilter,
    DatasetFilterId,
    dataset_filter_from_id,
)
from kiln_ai.datamodel.task_run import TaskRun

if TYPE_CHECKING:
    from kiln_ai.datamodel.task import Task


class DatasetToolInfo(BaseModel):
    """
    Information about tools used across task runs in a dataset split.
    """

    # True when at least two runs in the split disagree on their tool set.
    has_tool_mismatch: bool = Field(
        description="Whether the tools from each run match across all runs in the dataset split."
    )
    # Tool IDs shared by every run; empty on mismatch or when no tools exist.
    tools: list[str] = Field(
        default_factory=list,
        description="Common tool IDs shared by every run; empty when tools are mismatched or no tools exist.",
    )


class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    name: FilenameString = Field(
        description="The name of the dataset split definition."
    )
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )


# Common preset split layouts. Percentages in each preset sum to 1.0,
# which DatasetSplit.validate_split_percentages requires.
AllSplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="all", percentage=1.0)
]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.2),
]
Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="val", percentage=0.2),
]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.6),
    DatasetSplitDefinition(name="test", percentage=0.2),
    DatasetSplitDefinition(name="val", percentage=0.2),
]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.1),
    DatasetSplitDefinition(name="val", percentage=0.1),
]


class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: FilenameString = Field(description="The name of the dataset split.")
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )
    filter: DatasetFilterId | None = Field(
        default=None,
        description="The filter used to build the dataset.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        """Validate that the split percentages sum to exactly 1.0 (within float tolerance)."""
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.

        Args:
            name: name for the new dataset split
            task: the task whose runs are filtered and split
            splits: split definitions (percentages must sum to 1.0)
            filter_id: ID of the dataset filter used to select runs
            description: optional human-readable description

        Returns:
            DatasetSplit: a new split parented to the given task.
        """
        filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        """
        Shuffle the task's filtered run IDs and partition them by split percentage.

        The last split absorbs any rounding remainder so every valid ID is assigned.

        Args:
            task: the task whose runs are considered
            splits: split definitions to partition by
            filter: predicate selecting which task runs to include

        Returns:
            dict[str, list[str]]: split name -> list of task run IDs.
        """
        # Collect IDs of runs that pass the filter.
        valid_ids = [task_run.id for task_run in task.runs() if filter(task_run)]

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents: dict[str, list[str]] = {}
        start_idx = 0

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # Last split gets all remaining items (for rounding)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        """Return the parent Task, or None if the parent is missing or not a Task."""
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = {run.id for run in runs}
        all_ids_in_splits: set[str] = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)

    def _get_runs(self) -> list[TaskRun]:
        """
        Get all task runs referenced in this dataset split.

        Returns:
            list[TaskRun]: list of task runs in this dataset split
        """
        parent = self.parent_task()
        if parent is None:
            return []

        all_run_ids: set[str] = set()
        for run_ids in self.split_contents.values():
            all_run_ids.update(run_ids)

        # Find all runs by their IDs
        return [
            task_run
            for task_run in parent.runs(readonly=True)
            if task_run.id in all_run_ids
        ]

    @staticmethod
    def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo:
        """
        Compute tool info from a list of task runs.

        Args:
            runs: list of task runs to analyze

        Returns:
            DatasetToolInfo: information about tools used across the task runs
        """
        has_tool_mismatch = False
        tools: set[str] | None = None

        for run in runs:
            # Extract tools from run config, treating missing source/run_config/tools_config as empty tools
            run_tools: set[str] = set()
            source = run.output.source if run.output else None
            if source is not None and source.run_config is not None:
                tools_config = source.run_config.tools_config
                if tools_config is not None:
                    run_tools = set(tools_config.tools)

            # First run establishes the expected tool set (including empty)
            if tools is None:
                tools = run_tools
            elif run_tools != tools:
                # Mismatch found
                has_tool_mismatch = True
                tools = set()
                break

        # If no valid runs were processed, return empty tools
        if tools is None:
            tools = set()

        return DatasetToolInfo(has_tool_mismatch=has_tool_mismatch, tools=sorted(tools))

    def tool_info(self) -> DatasetToolInfo:
        """
        Helper method to compute tool info for the dataset split. Iterate through all runs in the dataset split and check the tools used in each run config.

        Returns:
            DatasetToolInfo: information about tools used across task runs in this dataset split
        """
        runs = self._get_runs()
        tool_info = self.compute_tool_info(runs)
        return tool_info
class DatasetToolInfo(BaseModel):
    """Summary of tool usage across the task runs of a dataset split."""

    # Set when at least two runs disagree on their configured tools.
    has_tool_mismatch: bool = Field(
        description="Whether the tools from each run match across all runs in the dataset split."
    )
    # Tool IDs every run shares; empty on mismatch or when none are used.
    tools: list[str] = Field(
        default_factory=list,
        description="Common tool IDs shared by every run; empty when tools are mismatched or no tools exist.",
    )
Information about tools used across task runs in a dataset split.
class DatasetSplitDefinition(BaseModel):
    """One named slice of a dataset with its fractional size.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    # Name identifying this slice, e.g. "train" or "test".
    name: FilenameString = Field(
        description="The name of the dataset split definition."
    )
    # Free-form note for humans; ignored by training.
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    # Fraction of the dataset assigned to this slice, in [0, 1].
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )
A definition of a split in a dataset.
Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: FilenameString = Field(description="The name of the dataset split.")
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )
    filter: DatasetFilterId | None = Field(
        default=None,
        description="The filter used to build the dataset.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        """Validate that the split percentages sum to exactly 1.0 (within float tolerance)."""
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.

        Args:
            name: name for the new dataset split
            task: the task whose runs are filtered and split
            splits: split definitions (percentages must sum to 1.0)
            filter_id: ID of the dataset filter used to select runs
            description: optional human-readable description

        Returns:
            DatasetSplit: a new split parented to the given task.
        """
        filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        """
        Shuffle the task's filtered run IDs and partition them by split percentage.

        The last split absorbs any rounding remainder so every valid ID is assigned.

        Returns:
            dict[str, list[str]]: split name -> list of task run IDs.
        """
        # Collect IDs of runs that pass the filter.
        valid_ids = [task_run.id for task_run in task.runs() if filter(task_run)]

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents: dict[str, list[str]] = {}
        start_idx = 0

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # Last split gets all remaining items (for rounding)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        """Return the parent Task, or None if the parent is missing or not a Task."""
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = {run.id for run in runs}
        all_ids_in_splits: set[str] = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)

    def _get_runs(self) -> list[TaskRun]:
        """
        Get all task runs referenced in this dataset split.

        Returns:
            list[TaskRun]: list of task runs in this dataset split
        """
        parent = self.parent_task()
        if parent is None:
            return []

        all_run_ids: set[str] = set()
        for run_ids in self.split_contents.values():
            all_run_ids.update(run_ids)

        # Find all runs by their IDs
        return [
            task_run
            for task_run in parent.runs(readonly=True)
            if task_run.id in all_run_ids
        ]

    @staticmethod
    def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo:
        """
        Compute tool info from a list of task runs.

        Args:
            runs: list of task runs to analyze

        Returns:
            DatasetToolInfo: information about tools used across the task runs
        """
        has_tool_mismatch = False
        tools: set[str] | None = None

        for run in runs:
            # Extract tools from run config, treating missing source/run_config/tools_config as empty tools
            run_tools: set[str] = set()
            source = run.output.source if run.output else None
            if source is not None and source.run_config is not None:
                tools_config = source.run_config.tools_config
                if tools_config is not None:
                    run_tools = set(tools_config.tools)

            # First run establishes the expected tool set (including empty)
            if tools is None:
                tools = run_tools
            elif run_tools != tools:
                # Mismatch found
                has_tool_mismatch = True
                tools = set()
                break

        # If no valid runs were processed, return empty tools
        if tools is None:
            tools = set()

        return DatasetToolInfo(has_tool_mismatch=has_tool_mismatch, tools=sorted(tools))

    def tool_info(self) -> DatasetToolInfo:
        """
        Helper method to compute tool info for the dataset split. Iterate through all runs in the dataset split and check the tools used in each run config.

        Returns:
            DatasetToolInfo: information about tools used across task runs in this dataset split
        """
        runs = self._get_runs()
        tool_info = self.compute_tool_info(runs)
        return tool_info
A collection of task runs, with optional splits (train, test, validation).
Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
Maintains a list of IDs for each split, to avoid data duplication.
108 @model_validator(mode="after") 109 def validate_split_percentages(self) -> "DatasetSplit": 110 total = sum(split.percentage for split in self.splits) 111 if not math.isclose(total, 1.0, rel_tol=1e-9): 112 raise ValueError(f"The sum of split percentages must be 1.0 (got {total})") 113 return self
115 @classmethod 116 def from_task( 117 cls, 118 name: str, 119 task: "Task", 120 splits: list[DatasetSplitDefinition], 121 filter_id: DatasetFilterId = "all", 122 description: str | None = None, 123 ): 124 """ 125 Build a dataset split from a task. 126 """ 127 filter = dataset_filter_from_id(filter_id) 128 split_contents = cls.build_split_contents(task, splits, filter) 129 return cls( 130 parent=task, 131 name=name, 132 description=description, 133 splits=splits, 134 split_contents=split_contents, 135 filter=filter_id, 136 )
Build a dataset split from a task.
138 @classmethod 139 def build_split_contents( 140 cls, 141 task: "Task", 142 splits: list[DatasetSplitDefinition], 143 filter: DatasetFilter, 144 ) -> dict[str, list[str]]: 145 valid_ids = [] 146 for task_run in task.runs(): 147 if filter(task_run): 148 valid_ids.append(task_run.id) 149 150 # Shuffle and split by split percentage 151 random.shuffle(valid_ids) 152 split_contents = {} 153 start_idx = 0 154 remaining_items = len(valid_ids) 155 156 # Handle all splits except the last one 157 for split in splits[:-1]: 158 split_size = round(len(valid_ids) * split.percentage) 159 split_contents[split.name] = valid_ids[start_idx : start_idx + split_size] 160 start_idx += split_size 161 remaining_items -= split_size 162 163 # Last split gets all remaining items (for rounding) 164 if splits: 165 split_contents[splits[-1].name] = valid_ids[start_idx:] 166 167 return split_contents
177 def missing_count(self) -> int: 178 """ 179 Returns: 180 int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset 181 """ 182 parent = self.parent_task() 183 if parent is None: 184 raise ValueError("DatasetSplit has no parent task") 185 186 runs = parent.runs(readonly=True) 187 all_ids = set(run.id for run in runs) 188 all_ids_in_splits = set() 189 for ids in self.split_contents.values(): 190 all_ids_in_splits.update(ids) 191 missing = all_ids_in_splits - all_ids 192 return len(missing)
Returns: int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
217 @staticmethod 218 def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo: 219 """ 220 Compute tool info from a list of task runs. 221 222 Args: 223 runs: list of task runs to analyze 224 225 Returns: 226 DatasetToolInfo: information about tools used across the task runs 227 """ 228 229 has_tool_mismatch = False 230 tools: set[str] | None = None 231 232 for run in runs: 233 # Extract tools from run config, treating missing source/run_config/tools_config as empty tools 234 run_tools: set[str] = set() 235 source = run.output.source if run.output else None 236 if source is not None and source.run_config is not None: 237 tools_config = source.run_config.tools_config 238 if tools_config is not None: 239 run_tools = set(tools_config.tools) 240 241 # First run establishes the expected tool set (including empty) 242 if tools is None: 243 tools = run_tools 244 elif run_tools != tools: 245 # Mismatch found 246 has_tool_mismatch = True 247 tools = set() 248 break 249 250 # If no valid runs were processed, return empty tools 251 if tools is None: 252 tools = set() 253 254 return DatasetToolInfo(has_tool_mismatch=has_tool_mismatch, tools=sorted(tools))
Compute tool info from a list of task runs.
Args: runs: list of task runs to analyze
Returns: DatasetToolInfo: information about tools used across the task runs
256 def tool_info(self) -> DatasetToolInfo: 257 """ 258 Helper method to compute tool info for the dataset split. Iterate through all runs in the dataset split and check the tools used in each run config. 259 260 Returns: 261 DatasetToolInfo: information about tools used across task runs in this dataset split 262 """ 263 runs = self._get_runs() 264 tool_info = self.compute_tool_info(runs) 265 return tool_info
Helper method to compute tool info for the dataset split. Iterate through all runs in the dataset split and check the tools used in each run config.
Returns: DatasetToolInfo: information about tools used across task runs in this dataset split
The type of the None singleton.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    # Already initialised elsewhere; nothing to do.
    if getattr(self, '__pydantic_private__', None) is not None:
        return
    pydantic_private = {}
    for name, private_attr in self.__private_attributes__.items():
        default = private_attr.get_default()
        # Only attributes with a concrete default get an initial value.
        if default is not PydanticUndefined:
            pydantic_private[name] = default
    object_setattr(self, '__pydantic_private__', pydantic_private)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args: self: The BaseModel instance. context: The context.