kiln_ai.adapters.fine_tune.base_finetune

from abc import ABC, abstractmethod
from typing import Literal

from pydantic import BaseModel

from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType, Task
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.datamodel.datamodel_enums import (
    ChatStrategy,
    ModelProviderName,
)
from kiln_ai.datamodel.run_config import RunConfigProperties
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None
    error_details: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        thinking_instructions: str | None,
        data_strategy: ChatStrategy,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
        run_config: RunConfigProperties | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Raise exception if run config is none
        if run_config is None:
            raise ValueError("Run config is required")

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=cls.augment_system_message(system_message, parent_task),
            thinking_instructions=thinking_instructions,
            parent=parent_task,
            data_strategy=data_strategy,
            run_config=run_config,
        )

        # Update the run config properties for fine-tuning
        run_config.model_provider_name = ModelProviderName.kiln_fine_tune
        run_config.model_name = datamodel.nested_id()
        run_config.prompt_id = f"fine_tune_prompt::{datamodel.nested_id()}"

        adapter = cls(datamodel)
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        """
        Augment the system message with additional instructions, such as JSON instructions.
        """

        # Base implementation does nothing, can be overridden by subclasses
        return system_message

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    if isinstance(value, int):
                        value = float(value)
                    else:
                        raise ValueError(
                            f"Parameter {parameter.name} must be a float, got {type(value)}"
                        )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")
class FineTuneStatus(pydantic.main.BaseModel):

The status of a fine-tune, including a user friendly message.

status: FineTuneStatusType
message: str | None
error_details: str | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
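
A minimal construction sketch. The specific FineTuneStatusType member used here ("running") is an assumption about the enum's values, which are defined in kiln_ai.datamodel rather than in this module.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType

# "running" is an assumed enum member; check FineTuneStatusType for the exact values.
status = FineTuneStatus(
    status=FineTuneStatusType.running,
    message="Training job is in progress",
)
print(status.status, status.message)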

class FineTuneParameter(pydantic.main.BaseModel):

A parameter for a fine-tune. Hyperparameters, etc.

name: str
type: Literal['string', 'int', 'float', 'bool']
description: str
optional: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
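
A short sketch of declaring a hyperparameter. The name and description are illustrative only; they are not required by any particular provider.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter

# optional defaults to True, so callers may omit this parameter.
epochs = FineTuneParameter(
    name="epochs",
    type="int",
    description="Number of training epochs",
)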

TYPE_MAP = {'string': <class 'str'>, 'int': <class 'int'>, 'float': <class 'float'>, 'bool': <class 'bool'>}
class BaseFinetuneAdapter(abc.ABC):

A base class for fine-tuning adapters.

datamodel: FinetuneModel
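
A minimal subclass sketch showing the two methods a concrete adapter must implement. The class and its behaviour are hypothetical: a real adapter would call a provider's fine-tuning API in _start and map the provider's job state in status. FineTuneStatusType.pending is an assumed enum member.

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneStatus,
)
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType


class ExampleFinetuneAdapter(BaseFinetuneAdapter):
    """Hypothetical adapter, for illustration only."""

    async def _start(self, dataset: DatasetSplit) -> None:
        # A real adapter would upload the training split and launch a provider job here,
        # recording any provider job id on self.datamodel so status() can find it later.
        ...

    async def status(self) -> FineTuneStatus:
        # A real adapter would poll the provider; "pending" is an assumed enum member.
        return FineTuneStatus(status=FineTuneStatusType.pending, message="Job queued")
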
@classmethod
async def create_and_start( cls, dataset: kiln_ai.datamodel.DatasetSplit, provider_id: str, provider_base_model_id: str, train_split_name: str, system_message: str, thinking_instructions: str | None, data_strategy: kiln_ai.datamodel.datamodel_enums.ChatStrategy, parameters: dict[str, str | int | float | bool] = {}, name: str | None = None, description: str | None = None, validation_split_name: str | None = None, run_config: kiln_ai.datamodel.run_config.RunConfigProperties | None = None) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:

Create and start a fine-tune. Validates the dataset splits and parameters, builds a Finetune datamodel under the dataset's parent task, points the run config at the new fine-tune, starts the job via _start, and saves the datamodel to disk.
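
A hedged usage sketch using the hypothetical ExampleFinetuneAdapter above. The provider and model ids are placeholders, "train" must be one of the dataset's split names, and ChatStrategy.single_turn is an assumption about that enum's members.

from kiln_ai.datamodel import DatasetSplit
from kiln_ai.datamodel.datamodel_enums import ChatStrategy
from kiln_ai.datamodel.run_config import RunConfigProperties


async def start_tune(dataset: DatasetSplit, run_config: RunConfigProperties):
    # Validates the splits and parameters, creates and saves a Finetune datamodel,
    # and starts the job via the adapter's _start implementation.
    adapter, finetune = await ExampleFinetuneAdapter.create_and_start(
        dataset=dataset,
        provider_id="example_provider",  # placeholder
        provider_base_model_id="example-base-model",  # placeholder
        train_split_name="train",  # must exist in dataset.split_contents
        system_message="You are a helpful assistant.",
        thinking_instructions=None,
        data_strategy=ChatStrategy.single_turn,  # assumed enum member
        run_config=run_config,
    )
    return adapter, finetune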

@classmethod
def augment_system_message(cls, system_message: str, task: kiln_ai.datamodel.Task) -> str:

Augment the system message with additional instructions, such as JSON instructions.
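
A sketch of an override that appends JSON instructions, building on the hypothetical ExampleFinetuneAdapter above. Whether Task exposes an output_json_schema field is an assumption about the datamodel, so the example reads it defensively with getattr.

from kiln_ai.datamodel import Task


class JsonAwareAdapter(ExampleFinetuneAdapter):
    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        # output_json_schema is an assumed Task field; otherwise return the message unchanged.
        if getattr(task, "output_json_schema", None):
            return system_message + "\n\nReturn the answer as JSON matching the task's output schema."
        return system_message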

@abstractmethod
async def status(self) -> FineTuneStatus:

Get the status of the fine-tune.
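
A polling sketch built on top of status(). The terminal states used here (completed, failed) are assumptions about FineTuneStatusType's members.

import asyncio

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter, FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType


async def wait_for_terminal_state(
    adapter: BaseFinetuneAdapter, poll_seconds: float = 60
) -> FineTuneStatus:
    # Keep polling until the provider reports an assumed terminal state.
    while True:
        status = await adapter.status()
        if status.status in (FineTuneStatusType.completed, FineTuneStatusType.failed):
            return status
        await asyncio.sleep(poll_seconds)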

@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:

Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
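
A sketch of declaring tunable hyperparameters in a subclass, again building on the hypothetical ExampleFinetuneAdapter. The parameter names and descriptions are illustrative, not those of any real provider adapter.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter


class TunableAdapter(ExampleFinetuneAdapter):
    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        # Illustrative hyperparameters; both are optional by default.
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs",
            ),
            FineTuneParameter(
                name="learning_rate",
                type="float",
                description="Learning rate multiplier",
            ),
        ]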

@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:

Validate the parameters for this fine-tune.
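
Continuing the hypothetical TunableAdapter sketch above: values are checked against the declared names and types before a tune is created, and undeclared keys raise a ValueError.

TunableAdapter.validate_parameters({"epochs": 3, "learning_rate": 0.1})  # passes

try:
    TunableAdapter.validate_parameters({"batch_size": 8})  # not declared
except ValueError as error:
    print(error)  # "Parameter batch_size is not available"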