kiln_ai.adapters.fine_tune.base_finetune
from abc import ABC, abstractmethod
from typing import Literal

from pydantic import BaseModel

from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType, Task
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.datamodel.datamodel_enums import (
    ChatStrategy,
    ModelProviderName,
)
from kiln_ai.datamodel.run_config import RunConfigProperties
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None
    error_details: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        thinking_instructions: str | None,
        data_strategy: ChatStrategy,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
        run_config: RunConfigProperties | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Raise exception if run config is none
        if run_config is None:
            raise ValueError("Run config is required")

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=cls.augment_system_message(system_message, parent_task),
            thinking_instructions=thinking_instructions,
            parent=parent_task,
            data_strategy=data_strategy,
            run_config=run_config,
        )

        # Update the run config properties for fine-tuning
        run_config.model_provider_name = ModelProviderName.kiln_fine_tune
        run_config.model_name = datamodel.nested_id()
        run_config.prompt_id = f"fine_tune_prompt::{datamodel.nested_id()}"

        adapter = cls(datamodel)
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        """
        Augment the system message with additional instructions, such as JSON instructions.
        """

        # Base implementation does nothing, can be overridden by subclasses
        return system_message

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    if isinstance(value, int):
                        value = float(value)
                    else:
                        raise ValueError(
                            f"Parameter {parameter.name} must be a float, got {type(value)}"
                        )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")
class FineTuneStatus(pydantic.main.BaseModel):
The status of a fine-tune, including a user friendly message.
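For illustration only, a minimal sketch of constructing a status object, as an adapter's status() method might return it. The FineTuneStatusType member name ("failed") and the message text are assumptions, not values defined in this module.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType

# Assumption: FineTuneStatusType exposes a "failed" member.
failed = FineTuneStatus(
    status=FineTuneStatusType.failed,
    message="Provider rejected the training file",
    error_details="Hypothetical provider error: invalid JSONL on line 12",
)
print(failed.status, failed.message)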
class FineTuneParameter(pydantic.main.BaseModel):
A parameter for a fine-tune. Hyperparameters, etc.
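A sketch of how an adapter might describe its tunable hyperparameters with this model; the specific names ("epochs", "learning_rate") are illustrative, not a particular provider's actual parameter list.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter

# Illustrative hyperparameter definitions (names and types are assumptions).
EXAMPLE_PARAMETERS = [
    FineTuneParameter(
        name="epochs",
        type="int",
        description="Number of training epochs",
        optional=True,
    ),
    FineTuneParameter(
        name="learning_rate",
        type="float",
        description="Learning rate multiplier",
        optional=True,
    ),
]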
TYPE_MAP = {'string': <class 'str'>, 'int': <class 'int'>, 'float': <class 'float'>, 'bool': <class 'bool'>}
class BaseFinetuneAdapter(abc.ABC):
A base class for fine-tuning adapters.
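Concrete adapters subclass this class and implement the two abstract coroutines, _start and status. A minimal sketch for a fictional provider, assuming FineTuneStatusType defines "running" and "unknown" members:

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneStatus,
)
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType


class ExampleFinetuneAdapter(BaseFinetuneAdapter):
    """Hypothetical adapter for a fictional provider, for illustration only."""

    async def _start(self, dataset: DatasetSplit) -> None:
        # A real adapter would format the dataset splits, upload them to the
        # provider, create the training job, and record provider-side IDs on
        # self.datamodel before it is saved.
        self._job_started = True

    async def status(self) -> FineTuneStatus:
        # A real adapter would poll the provider's job API here.
        # Assumption: FineTuneStatusType defines "running" and "unknown" members.
        if getattr(self, "_job_started", False):
            return FineTuneStatus(
                status=FineTuneStatusType.running,
                message="Training job is in progress",
            )
        return FineTuneStatus(
            status=FineTuneStatusType.unknown,
            message="Job has not been started",
        )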
@classmethod
async def create_and_start(
    cls,
    dataset: kiln_ai.datamodel.DatasetSplit,
    provider_id: str,
    provider_base_model_id: str,
    train_split_name: str,
    system_message: str,
    thinking_instructions: str | None,
    data_strategy: kiln_ai.datamodel.datamodel_enums.ChatStrategy,
    parameters: dict[str, str | int | float | bool] = {},
    name: str | None = None,
    description: str | None = None,
    validation_split_name: str | None = None,
    run_config: kiln_ai.datamodel.run_config.RunConfigProperties | None = None,
) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:
Create and start a fine-tune.
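A hedged usage sketch of this factory method, assuming a DatasetSplit with a "train" split and a RunConfigProperties instance have already been created elsewhere, and reusing the illustrative ExampleFinetuneAdapter sketched above. The ChatStrategy member name is an assumption.

import asyncio

from kiln_ai.datamodel.datamodel_enums import ChatStrategy


async def launch_fine_tune(dataset, run_config):
    # dataset: a kiln_ai.datamodel.DatasetSplit with a parent task and a "train" split.
    # run_config: a kiln_ai.datamodel.run_config.RunConfigProperties instance.
    adapter, finetune = await ExampleFinetuneAdapter.create_and_start(
        dataset=dataset,
        provider_id="example_provider",           # illustrative value
        provider_base_model_id="example-base-1",  # illustrative value
        train_split_name="train",
        system_message="You are a helpful assistant.",
        thinking_instructions=None,
        data_strategy=ChatStrategy.single_turn,   # assumption: member name
        # parameters can be passed here if the adapter declares them via
        # available_parameters(); omitted so validate_parameters() passes.
        run_config=run_config,
    )
    status = await adapter.status()
    print(finetune.name, status.status)


# asyncio.run(launch_fine_tune(dataset, run_config))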
@classmethod
def augment_system_message(cls, system_message: str, task: kiln_ai.datamodel.Task) -> str:
Augment the system message with additional instructions, such as JSON instructions.
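Subclasses can override this hook to append format-specific instructions. A sketch that appends a JSON instruction when the task appears to expect structured output; the output_json_schema attribute is an assumption about the Task model, hence the defensive getattr:

from kiln_ai.datamodel import Task


class JsonAwareAdapter(ExampleFinetuneAdapter):
    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        # Assumption: Task exposes an output_json_schema attribute when the
        # task expects structured output; getattr keeps this a soft check.
        if getattr(task, "output_json_schema", None):
            return system_message + "\n\nAlways respond with valid JSON."
        return system_message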
@abstractmethod
async def status(self) -> FineTuneStatus:
Get the status of the fine-tune.
@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:
Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
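A sketch of declaring tunable parameters on the illustrative adapter, reusing the EXAMPLE_PARAMETERS list defined earlier. create_and_start() rejects any parameter key not declared here, and non-optional parameters must be supplied.

class ExampleFinetuneAdapterWithParams(ExampleFinetuneAdapter):
    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        # Reuses the illustrative EXAMPLE_PARAMETERS defined above; a real
        # adapter returns whatever its provider actually supports.
        return EXAMPLE_PARAMETERS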
@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:
Validate the parameters for this fine-tune.
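For illustration, how validation behaves with the parameters declared above: an int is accepted where a float is expected, while other numeric mismatches and undeclared keys raise ValueError.

# Passes: both keys are declared; the int 1 is accepted for the float-typed
# learning_rate (an int is treated as a valid float value).
ExampleFinetuneAdapterWithParams.validate_parameters(
    {"epochs": 3, "learning_rate": 1}
)

# Raises ValueError: epochs is declared as an int, so a float is rejected.
# ExampleFinetuneAdapterWithParams.validate_parameters({"epochs": 3.5})

# Raises ValueError: batch_size is not a declared parameter.
# ExampleFinetuneAdapterWithParams.validate_parameters({"batch_size": 8})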