kiln_ai.adapters.fine_tune.base_finetune

from abc import ABC, abstractmethod
from typing import Literal

from pydantic import BaseModel

from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType, Task
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.datamodel.datamodel_enums import ChatStrategy
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None
    error_details: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        thinking_instructions: str | None,
        data_strategy: ChatStrategy,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=cls.augment_system_message(system_message, parent_task),
            thinking_instructions=thinking_instructions,
            parent=parent_task,
            data_strategy=data_strategy,
        )

        adapter = cls(datamodel)
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        """
        Augment the system message with additional instructions, such as JSON instructions.
        """

        # Base implementation does nothing, can be overridden by subclasses
        return system_message

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    if isinstance(value, int):
                        value = float(value)
                    else:
                        raise ValueError(
                            f"Parameter {parameter.name} must be a float, got {type(value)}"
                        )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")
class FineTuneStatus(pydantic.main.BaseModel):

The status of a fine-tune, including a user friendly message.

status: FineTuneStatusType
message: str | None
error_details: str | None
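A minimal sketch of what an adapter's status() might return; the FineTuneStatusType member names used here are assumptions about that enum, not names defined in this module.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType

# Assumed enum members (running, failed); adjust to the real FineTuneStatusType values.
in_progress = FineTuneStatus(
    status=FineTuneStatusType.running,
    message="Job accepted by the provider; training in progress.",
)
failure = FineTuneStatus(
    status=FineTuneStatusType.failed,
    message="The provider rejected the training file.",
    error_details="400 Bad Request: invalid JSONL on line 12",
)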

class FineTuneParameter(pydantic.main.BaseModel):

A parameter for a fine-tune. Hyperparameters, etc.

name: str
type: Literal['string', 'int', 'float', 'bool']
description: str
optional: bool
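A minimal sketch of two parameter descriptors; the names and descriptions are illustrative, not parameters defined by this module. The type field must be one of the four literal strings, and optional=False makes validate_parameters treat the parameter as required.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter

# Illustrative hyperparameter descriptors for a hypothetical provider.
epochs = FineTuneParameter(
    name="epochs",
    type="int",
    description="Number of training epochs.",
    optional=False,  # required: validate_parameters raises if it is missing
)
learning_rate = FineTuneParameter(
    name="learning_rate",
    type="float",
    description="Peak learning rate.",
)  # optional defaults to True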

TYPE_MAP = {"string": str, "int": int, "float": float, "bool": bool}
class BaseFinetuneAdapter(abc.ABC):

A base class for fine-tuning adapters.

datamodel: FinetuneModel
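Concrete adapters subclass BaseFinetuneAdapter and implement the provider-specific pieces. A minimal sketch, assuming a hypothetical provider whose API calls are omitted; the FineTuneStatusType member name used in status() is an assumption about that enum.

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter, FineTuneStatus
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType


class ExampleFinetuneAdapter(BaseFinetuneAdapter):
    """Illustrative adapter; provider API calls are placeholders, not a real Kiln integration."""

    async def _start(self, dataset: DatasetSplit) -> None:
        # Format dataset.split_contents[self.datamodel.train_split_name] into the
        # provider's training-file format, upload it, and create the provider job.
        ...

    async def status(self) -> FineTuneStatus:
        # Query the provider for the job created in _start. The enum member name
        # below is an assumption about FineTuneStatusType.
        return FineTuneStatus(status=FineTuneStatusType.running, message="Training in progress")
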
@classmethod
async def create_and_start(
    cls,
    dataset: kiln_ai.datamodel.DatasetSplit,
    provider_id: str,
    provider_base_model_id: str,
    train_split_name: str,
    system_message: str,
    thinking_instructions: str | None,
    data_strategy: kiln_ai.datamodel.datamodel_enums.ChatStrategy,
    parameters: dict[str, str | int | float | bool] = {},
    name: str | None = None,
    description: str | None = None,
    validation_split_name: str | None = None,
) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:

Create and start a fine-tune. Validates the requested splits and parameters, builds a Finetune record under the dataset's parent task, starts the provider job via _start, saves the record to file, and returns both the adapter and the record.
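A hedged usage sketch. ExampleFinetuneAdapter is the hypothetical adapter sketched above; the split name, provider and model ids, and the ChatStrategy member are illustrative assumptions.

from kiln_ai.datamodel import DatasetSplit
from kiln_ai.datamodel.datamodel_enums import ChatStrategy


async def launch_finetune(dataset: DatasetSplit):
    # train_split_name must be a key of dataset.split_contents; any entries in the
    # parameters argument must be declared by the adapter's available_parameters().
    adapter, finetune = await ExampleFinetuneAdapter.create_and_start(
        dataset=dataset,
        provider_id="example_provider",          # hypothetical provider id
        provider_base_model_id="base-model-1b",  # hypothetical base model id
        train_split_name="train",
        system_message="You are a helpful assistant.",
        thinking_instructions=None,
        data_strategy=ChatStrategy.single_turn,  # assumed enum member
    )
    return adapter, finetune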

@classmethod
def augment_system_message(cls, system_message: str, task: kiln_ai.datamodel.Task) -> str:

Augment the system message with additional instructions, such as JSON instructions.
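A sketch of an override that appends JSON guidance when the task defines structured output, continuing the hypothetical ExampleFinetuneAdapter above; the output_json_schema attribute is an assumption about the Task datamodel, hence the defensive getattr.

from kiln_ai.datamodel import Task


class JsonAwareAdapter(ExampleFinetuneAdapter):
    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        # Assumption: Task may expose an output_json_schema field describing
        # structured output; fall back to the unmodified message otherwise.
        if getattr(task, "output_json_schema", None):
            return system_message + "\n\nReturn your final answer as valid JSON."
        return system_message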

@abstractmethod
async def status(self) -> FineTuneStatus:

Get the status of the fine-tune.
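On the caller side, a hedged polling sketch; the terminal FineTuneStatusType members named here are assumptions about that enum.

import asyncio

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter, FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType


async def wait_until_done(adapter: BaseFinetuneAdapter, poll_seconds: float = 60.0) -> FineTuneStatus:
    # Poll the provider until the job reaches a terminal state. The member names
    # below (completed, failed) are assumptions about FineTuneStatusType.
    terminal = {FineTuneStatusType.completed, FineTuneStatusType.failed}
    while True:
        current = await adapter.status()
        if current.status in terminal:
            return current
        await asyncio.sleep(poll_seconds)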

@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:

Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
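A sketch of declaring the hyperparameters a hypothetical provider accepts; the names and descriptions are illustrative, and the abstract _start/status methods are omitted for brevity.

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter, FineTuneParameter


class ParameterizedAdapter(BaseFinetuneAdapter):
    # _start and status omitted; see the adapter sketch earlier on this page.

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        # Descriptors consumed by validate_parameters before a job is created.
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs.",
                optional=True,
            ),
            FineTuneParameter(
                name="learning_rate",
                type="float",
                description="Peak learning rate used by the provider.",
                optional=True,
            ),
        ]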

@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:

Validate the parameters for this fine-tune. Raises ValueError if a required parameter is missing, if a value does not match its declared type (ints are accepted where floats are expected), or if a parameter is not declared by available_parameters.
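A hedged sketch of how validation behaves, using a throwaway adapter that declares a single required parameter; the abstract job methods can be omitted because validate_parameters is a classmethod.

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter, FineTuneParameter


class _DemoAdapter(BaseFinetuneAdapter):
    # Hypothetical adapter used only to exercise validation; _start/status omitted.
    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs.",
                optional=False,
            )
        ]


_DemoAdapter.validate_parameters({"epochs": 3})  # passes: required int is present

try:
    _DemoAdapter.validate_parameters({"epochs": "3"})
except ValueError as err:
    print(err)  # Parameter epochs must be an integer, got <class 'str'>

try:
    _DemoAdapter.validate_parameters({"epochs": 3, "batch_size": 8})
except ValueError as err:
    print(err)  # Parameter batch_size is not available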