kiln_ai.datamodel.dataset_split
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
"""
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
"""

import math
import random
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field, model_validator

from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
from kiln_ai.datamodel.dataset_filters import (
    DatasetFilter,
    DatasetFilterId,
    dataset_filter_from_id,
)
from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties
from kiln_ai.datamodel.task_run import TaskRun

if TYPE_CHECKING:
    from kiln_ai.datamodel.task import Task


class DatasetToolInfo(BaseModel):
    """
    Information about tools used across task runs in a dataset split.
    """

    # True when at least two runs in the split used different tool sets
    has_tool_mismatch: bool = Field(
        description="Whether the tools from each run match across all runs in the dataset split."
    )
    # Sorted common tool IDs; None signals the runs had mismatched tool sets
    tools: list[str] | None = Field(
        default_factory=list,
        description="Common tool IDs shared by every run. Empty means every run has no tools. None means the dataset contains mismatched tool sets.",
    )


class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    name: FilenameString = Field(
        description="The name of the dataset split definition."
    )
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )


# Common reusable split layouts; each list's percentages sum to 1.0
AllSplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="all", percentage=1.0)
]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.2),
]
Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="val", percentage=0.2),
]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.6),
    DatasetSplitDefinition(name="test", percentage=0.2),
    DatasetSplitDefinition(name="val", percentage=0.2),
]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.1),
    DatasetSplitDefinition(name="val", percentage=0.1),
]


class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: FilenameString = Field(description="The name of the dataset split.")
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )
    filter: DatasetFilterId | None = Field(
        default=None,
        description="The filter used to build the dataset.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        """Validate that the split percentages sum to 1.0.

        Raises:
            ValueError: if the percentages do not sum to 1.0 (within float tolerance).
        """
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.

        Args:
            name: name for the new dataset split
            task: the task whose runs are frozen into the split
            splits: the split definitions (percentages must sum to 1.0)
            filter_id: ID of the dataset filter used to select runs
            description: optional human-readable description

        Returns:
            DatasetSplit: a new split parented to the task
        """
        filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        """Select, shuffle, and partition the task's run IDs by split percentage.

        Args:
            task: the task whose runs are considered
            splits: split definitions to partition into
            filter: predicate selecting which task runs to include

        Returns:
            dict[str, list[str]]: mapping of split name to a list of task run IDs
        """
        valid_ids = [task_run.id for task_run in task.runs() if filter(task_run)]

        # Shuffle so each split is a random sample, then carve out each split
        random.shuffle(valid_ids)
        split_contents: dict[str, list[str]] = {}
        start_idx = 0

        # All splits except the last take round(total * percentage) items
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # Last split gets all remaining items (absorbs rounding remainder)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        """Return the parent Task, or None if the parent is missing or not a Task."""
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset

        Raises:
            ValueError: if this split has no parent task.
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = set(run.id for run in runs)
        all_ids_in_splits = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)

    def _get_runs(self) -> list[TaskRun]:
        """
        Get all task runs referenced in this dataset split.

        Returns:
            list[TaskRun]: list of task runs in this dataset split
        """
        parent = self.parent_task()
        if parent is None:
            return []

        runs = []
        all_run_ids = set()
        for run_ids in self.split_contents.values():
            all_run_ids.update(run_ids)

        # Find all runs by their IDs
        for task_run in parent.runs(readonly=True):
            if task_run.id in all_run_ids:
                runs.append(task_run)

        return runs

    @staticmethod
    def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo:
        """
        Compute tool info from a list of task runs.

        Args:
            runs: list of task runs to analyze

        Returns:
            DatasetToolInfo: information about tools used across the task runs
        """

        has_tool_mismatch = False
        tools: set[str] | None = None

        for run in runs:
            # Extract tools from run config, treating missing source/run_config/tools_config as empty tools
            run_tools: set[str] = set()
            source = run.output.source if run.output else None
            if source is not None and isinstance(
                source.run_config, KilnAgentRunConfigProperties
            ):
                tools_config = source.run_config.tools_config
                if tools_config is not None:
                    run_tools = set(tools_config.tools)

            # First run establishes the expected tool set (including empty)
            if tools is None:
                tools = run_tools
            elif run_tools != tools:
                # Mismatch found
                has_tool_mismatch = True
                tools = None
                break

        # If no valid runs were processed, return empty tools
        if tools is None:
            if not has_tool_mismatch:
                tools = set()

        return DatasetToolInfo(
            has_tool_mismatch=has_tool_mismatch,
            tools=None if tools is None else sorted(tools),
        )

    def tool_info(self) -> DatasetToolInfo:
        """
        Helper method to compute tool info for the dataset split. Iterate through all runs in the dataset split and check the tools used in each run config.

        Returns:
            DatasetToolInfo: information about tools used across task runs in this dataset split
        """
        runs = self._get_runs()
        tool_info = self.compute_tool_info(runs)
        return tool_info
class DatasetToolInfo(BaseModel):
    """
    Information about tools used across task runs in a dataset split.
    """

    # True when at least two runs in the split used different tool sets
    has_tool_mismatch: bool = Field(
        description="Whether the tools from each run match across all runs in the dataset split."
    )
    # Sorted common tool IDs; None signals the runs had mismatched tool sets
    tools: list[str] | None = Field(
        default_factory=list,
        description="Common tool IDs shared by every run. Empty means every run has no tools. None means the dataset contains mismatched tool sets.",
    )
Information about tools used across task runs in a dataset split.
class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    # Split name, e.g. "train", "test", "val"
    name: FilenameString = Field(
        description="The name of the dataset split definition."
    )
    # Free-form human-readable description; never sent to training
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    # Fraction of the dataset in [0, 1]; all splits together must sum to 1.0
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )
A definition of a split in a dataset.
Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: FilenameString = Field(description="The name of the dataset split.")
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )
    filter: DatasetFilterId | None = Field(
        default=None,
        description="The filter used to build the dataset.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        """Validate that the split percentages sum to 1.0.

        Raises:
            ValueError: if the percentages do not sum to 1.0 (within float tolerance).
        """
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter_id: DatasetFilterId = "all",
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.

        Args:
            name: name for the new dataset split
            task: the task whose runs are frozen into the split
            splits: the split definitions (percentages must sum to 1.0)
            filter_id: ID of the dataset filter used to select runs
            description: optional human-readable description

        Returns:
            DatasetSplit: a new split parented to the task
        """
        filter = dataset_filter_from_id(filter_id)
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
            filter=filter_id,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        """Select, shuffle, and partition the task's run IDs by split percentage.

        Args:
            task: the task whose runs are considered
            splits: split definitions to partition into
            filter: predicate selecting which task runs to include

        Returns:
            dict[str, list[str]]: mapping of split name to a list of task run IDs
        """
        valid_ids = [task_run.id for task_run in task.runs() if filter(task_run)]

        # Shuffle so each split is a random sample, then carve out each split
        random.shuffle(valid_ids)
        split_contents: dict[str, list[str]] = {}
        start_idx = 0

        # All splits except the last take round(total * percentage) items
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # Last split gets all remaining items (absorbs rounding remainder)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        """Return the parent Task, or None if the parent is missing or not a Task."""
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset

        Raises:
            ValueError: if this split has no parent task.
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = set(run.id for run in runs)
        all_ids_in_splits = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)

    def _get_runs(self) -> list[TaskRun]:
        """
        Get all task runs referenced in this dataset split.

        Returns:
            list[TaskRun]: list of task runs in this dataset split
        """
        parent = self.parent_task()
        if parent is None:
            return []

        runs = []
        all_run_ids = set()
        for run_ids in self.split_contents.values():
            all_run_ids.update(run_ids)

        # Find all runs by their IDs
        for task_run in parent.runs(readonly=True):
            if task_run.id in all_run_ids:
                runs.append(task_run)

        return runs

    @staticmethod
    def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo:
        """
        Compute tool info from a list of task runs.

        Args:
            runs: list of task runs to analyze

        Returns:
            DatasetToolInfo: information about tools used across the task runs
        """

        has_tool_mismatch = False
        tools: set[str] | None = None

        for run in runs:
            # Extract tools from run config, treating missing source/run_config/tools_config as empty tools
            run_tools: set[str] = set()
            source = run.output.source if run.output else None
            if source is not None and isinstance(
                source.run_config, KilnAgentRunConfigProperties
            ):
                tools_config = source.run_config.tools_config
                if tools_config is not None:
                    run_tools = set(tools_config.tools)

            # First run establishes the expected tool set (including empty)
            if tools is None:
                tools = run_tools
            elif run_tools != tools:
                # Mismatch found
                has_tool_mismatch = True
                tools = None
                break

        # If no valid runs were processed, return empty tools
        if tools is None:
            if not has_tool_mismatch:
                tools = set()

        return DatasetToolInfo(
            has_tool_mismatch=has_tool_mismatch,
            tools=None if tools is None else sorted(tools),
        )

    def tool_info(self) -> DatasetToolInfo:
        """
        Helper method to compute tool info for the dataset split. Iterate through all runs in the dataset split and check the tools used in each run config.

        Returns:
            DatasetToolInfo: information about tools used across task runs in this dataset split
        """
        runs = self._get_runs()
        tool_info = self.compute_tool_info(runs)
        return tool_info
A collection of task runs, with optional splits (train, test, validation).
Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
Maintains a list of IDs for each split, to avoid data duplication.
109 @model_validator(mode="after") 110 def validate_split_percentages(self) -> "DatasetSplit": 111 total = sum(split.percentage for split in self.splits) 112 if not math.isclose(total, 1.0, rel_tol=1e-9): 113 raise ValueError(f"The sum of split percentages must be 1.0 (got {total})") 114 return self
116 @classmethod 117 def from_task( 118 cls, 119 name: str, 120 task: "Task", 121 splits: list[DatasetSplitDefinition], 122 filter_id: DatasetFilterId = "all", 123 description: str | None = None, 124 ): 125 """ 126 Build a dataset split from a task. 127 """ 128 filter = dataset_filter_from_id(filter_id) 129 split_contents = cls.build_split_contents(task, splits, filter) 130 return cls( 131 parent=task, 132 name=name, 133 description=description, 134 splits=splits, 135 split_contents=split_contents, 136 filter=filter_id, 137 )
Build a dataset split from a task.
139 @classmethod 140 def build_split_contents( 141 cls, 142 task: "Task", 143 splits: list[DatasetSplitDefinition], 144 filter: DatasetFilter, 145 ) -> dict[str, list[str]]: 146 valid_ids = [] 147 for task_run in task.runs(): 148 if filter(task_run): 149 valid_ids.append(task_run.id) 150 151 # Shuffle and split by split percentage 152 random.shuffle(valid_ids) 153 split_contents = {} 154 start_idx = 0 155 remaining_items = len(valid_ids) 156 157 # Handle all splits except the last one 158 for split in splits[:-1]: 159 split_size = round(len(valid_ids) * split.percentage) 160 split_contents[split.name] = valid_ids[start_idx : start_idx + split_size] 161 start_idx += split_size 162 remaining_items -= split_size 163 164 # Last split gets all remaining items (for rounding) 165 if splits: 166 split_contents[splits[-1].name] = valid_ids[start_idx:] 167 168 return split_contents
178 def missing_count(self) -> int: 179 """ 180 Returns: 181 int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset 182 """ 183 parent = self.parent_task() 184 if parent is None: 185 raise ValueError("DatasetSplit has no parent task") 186 187 runs = parent.runs(readonly=True) 188 all_ids = set(run.id for run in runs) 189 all_ids_in_splits = set() 190 for ids in self.split_contents.values(): 191 all_ids_in_splits.update(ids) 192 missing = all_ids_in_splits - all_ids 193 return len(missing)
Returns: int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
218 @staticmethod 219 def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo: 220 """ 221 Compute tool info from a list of task runs. 222 223 Args: 224 runs: list of task runs to analyze 225 226 Returns: 227 DatasetToolInfo: information about tools used across the task runs 228 """ 229 230 has_tool_mismatch = False 231 tools: set[str] | None = None 232 233 for run in runs: 234 # Extract tools from run config, treating missing source/run_config/tools_config as empty tools 235 run_tools: set[str] = set() 236 source = run.output.source if run.output else None 237 if source is not None and isinstance( 238 source.run_config, KilnAgentRunConfigProperties 239 ): 240 tools_config = source.run_config.tools_config 241 if tools_config is not None: 242 run_tools = set(tools_config.tools) 243 244 # First run establishes the expected tool set (including empty) 245 if tools is None: 246 tools = run_tools 247 elif run_tools != tools: 248 # Mismatch found 249 has_tool_mismatch = True 250 tools = None 251 break 252 253 # If no valid runs were processed, return empty tools 254 if tools is None: 255 if not has_tool_mismatch: 256 tools = set() 257 258 return DatasetToolInfo( 259 has_tool_mismatch=has_tool_mismatch, 260 tools=None if tools is None else sorted(tools), 261 )
Compute tool info from a list of task runs.
Args: runs: list of task runs to analyze
Returns: DatasetToolInfo: information about tools used across the task runs
263 def tool_info(self) -> DatasetToolInfo: 264 """ 265 Helper method to compute tool info for the dataset split. Iterate through all runs in the dataset split and check the tools used in each run config. 266 267 Returns: 268 DatasetToolInfo: information about tools used across task runs in this dataset split 269 """ 270 runs = self._get_runs() 271 tool_info = self.compute_tool_info(runs) 272 return tool_info
Helper method to compute tool info for the dataset split. Iterate through all runs in the dataset split and check the tools used in each run config.
Returns: DatasetToolInfo: information about tools used across task runs in this dataset split
The type of the None singleton.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    # Only initialise when private state has not been set up yet
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        # Collect defaults for declared private attributes; attributes without
        # a default (PydanticUndefined sentinel) are deliberately left unset
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        # NOTE(review): object_setattr presumably bypasses pydantic's validating
        # __setattr__ — confirm against pydantic internals
        object_setattr(self, '__pydantic_private__', pydantic_private)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args: self: The BaseModel instance. context: The context.