kiln_ai.adapters.fine_tune.base_finetune
from abc import ABC, abstractmethod
from typing import Literal

from pydantic import BaseModel

from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=system_message,
            parent=parent_task,
        )

        adapter = cls(datamodel)
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    raise ValueError(
                        f"Parameter {parameter.name} must be a float, got {type(value)}"
                    )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )
class FineTuneStatus(pydantic.main.BaseModel):
The status of a fine-tune, including a user friendly message.
class FineTuneParameter(pydantic.main.BaseModel):
A parameter for a fine-tune. Hyperparameters, etc.
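For illustration, a hypothetical hyperparameter declaration; the field names match the model above, but the specific values here are invented, not part of the module:

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter

# A hypothetical learning-rate hyperparameter. An adapter's
# available_parameters() would return entries shaped like this.
learning_rate = FineTuneParameter(
    name="learning_rate",
    type="float",
    description="Step size used by the optimizer during training",
    optional=True,
)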
TYPE_MAP = {'string': <class 'str'>, 'int': <class 'int'>, 'float': <class 'float'>, 'bool': <class 'bool'>}
class BaseFinetuneAdapter(abc.ABC):
A base class for fine-tuning adapters.
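Concrete adapters subclass BaseFinetuneAdapter and implement the abstract _start and status methods. A minimal sketch of a toy subclass, assuming FineTuneStatusType exposes a "completed" member (check the enum in kiln_ai.datamodel for its actual values):

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneStatus,
)
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType


class EchoFinetuneAdapter(BaseFinetuneAdapter):
    # Toy adapter: records that training "started" and always reports
    # completion. A real adapter would submit a job to a provider API
    # and track the provider's job id via self.datamodel.
    async def _start(self, dataset: DatasetSplit) -> None:
        self._started = True

    async def status(self) -> FineTuneStatus:
        # FineTuneStatusType.completed is an assumption here; consult
        # kiln_ai.datamodel for the enum's real members.
        return FineTuneStatus(
            status=FineTuneStatusType.completed,
            message="Toy adapter: nothing to train",
        )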
@classmethod
async def create_and_start(
    cls,
    dataset: kiln_ai.datamodel.DatasetSplit,
    provider_id: str,
    provider_base_model_id: str,
    train_split_name: str,
    system_message: str,
    parameters: dict[str, str | int | float | bool] = {},
    name: str | None = None,
    description: str | None = None,
    validation_split_name: str | None = None,
) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:
Create and start a fine-tune.
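A sketch of the call site, using the hypothetical EchoFinetuneAdapter above. It assumes a saved DatasetSplit whose split_contents includes a "train" split, and the provider and model ids are assumptions that must exist in built_in_models for the validity check to pass:

import asyncio


async def launch(dataset):  # dataset: a saved kiln_ai.datamodel.DatasetSplit
    adapter, finetune = await EchoFinetuneAdapter.create_and_start(
        dataset=dataset,
        provider_id="openai",  # assumed provider name
        provider_base_model_id="gpt-4o-mini-2024-07-18",  # assumed finetune id
        train_split_name="train",  # must exist in dataset.split_contents
        system_message="You are a helpful assistant.",
    )
    # The Finetune record is created under the dataset's parent task
    # and saved to disk before being returned.
    print(finetune.name, await adapter.status())


# asyncio.run(launch(my_dataset_split))

Note that a name is generated automatically when none is provided, and validate_parameters runs before the Finetune record is created, so invalid hyperparameters fail fast.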
@abstractmethod
async def status(self) -> FineTuneStatus:
Get the status of the fine-tune.
@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:
Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
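The base implementation returns an empty list; subclasses override it to advertise their tunable knobs. A hypothetical override, continuing the EchoFinetuneAdapter sketch above (the parameter names and types are illustrative):

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter


class TunableAdapter(EchoFinetuneAdapter):
    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs",
                optional=True,
            ),
            FineTuneParameter(
                name="learning_rate",
                type="float",
                description="Optimizer step size",
                optional=False,  # callers must supply this one
            ),
        ]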
@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:
Validate the parameters for this fine-tune.
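Validation enforces required parameters, strict numeric types (an int is rejected where a float is declared, and vice versa), and rejects unknown keys. One caveat: a bool supplied for an int parameter passes, because bool is a subclass of int in Python. Using the hypothetical TunableAdapter above:

# Required parameter missing -> ValueError
try:
    TunableAdapter.validate_parameters({"epochs": 3})
except ValueError as err:
    print(err)  # Parameter learning_rate is required

# Wrong type: ints are not accepted where a float is declared
try:
    TunableAdapter.validate_parameters({"learning_rate": 1})
except ValueError as err:
    print(err)  # Parameter learning_rate must be a float, got <class 'int'>

# Valid parameters: no exception, returns None
TunableAdapter.validate_parameters({"learning_rate": 1e-4, "epochs": 3})

# Unknown keys are rejected as well
try:
    TunableAdapter.validate_parameters({"learning_rate": 1e-4, "batch": 8})
except ValueError as err:
    print(err)  # Parameter batch is not available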
@classmethod
def check_valid_provider_model(cls, provider_id: str, provider_base_model_id: str) -> None:
Check if the provider and base model are valid.
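The check walks built_in_models from kiln_ai.adapters.ml_model_list, returning silently when a provider with a matching name and provider_finetune_id is found, and raising ValueError otherwise. A sketch (the ids are assumptions, not guaranteed to be in the registry):

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter

try:
    BaseFinetuneAdapter.check_valid_provider_model(
        provider_id="openai",  # assumed provider name
        provider_base_model_id="gpt-4o-mini-2024-07-18",  # assumed finetune id
    )
    print("provider/model pair is registered")
except ValueError as err:
    print(err)  # Provider ... with base model ... is not available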