kiln_ai.adapters.fine_tune.base_finetune

from abc import ABC, abstractmethod
from typing import Literal

from pydantic import BaseModel

from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user-friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=system_message,
            parent=parent_task,
        )

        adapter = cls(datamodel)
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    raise ValueError(
                        f"Parameter {parameter.name} must be a float, got {type(value)}"
                    )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )
class FineTuneStatus(pydantic.main.BaseModel):

The status of a fine-tune, including a user-friendly message.

status: FineTuneStatusType
message: str | None

class FineTuneParameter(pydantic.main.BaseModel):

A parameter for a fine-tune. Hyperparameters, etc.

name: str
type: Literal['string', 'int', 'float', 'bool']
description: str
optional: bool

TYPE_MAP = {'string': str, 'int': int, 'float': float, 'bool': bool}
class BaseFinetuneAdapter(abc.ABC):

A base class for fine-tuning adapters.

datamodel: FinetuneModel
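
A concrete adapter implements _start and status, and usually overrides available_parameters. Below is a minimal sketch, not a real provider integration: the job logic is stubbed out, the parameter names are illustrative, and FineTuneStatusType.pending is assumed to be a valid status member.

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
    FineTuneStatus,
)
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType


class ExampleFinetuneAdapter(BaseFinetuneAdapter):
    """Illustrative subclass; a real adapter would call a provider API."""

    async def _start(self, dataset: DatasetSplit) -> None:
        # A real adapter would format the examples in
        # dataset.split_contents[self.datamodel.train_split_name] into the
        # provider's training-file format, upload it, and launch the job.
        self._job_started = True

    async def status(self) -> FineTuneStatus:
        # A real adapter would poll the provider's job API and map its state
        # onto FineTuneStatusType; "pending" is assumed to be a valid member.
        return FineTuneStatus(
            status=FineTuneStatusType.pending,
            message="Job submitted; training has not started yet.",
        )

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs",
                optional=False,  # validate_parameters will require this key
            ),
            FineTuneParameter(
                name="learning_rate_multiplier",
                type="float",
                description="Scales the provider's default learning rate",
            ),
        ]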
@classmethod
async def create_and_start(
    cls,
    dataset: kiln_ai.datamodel.DatasetSplit,
    provider_id: str,
    provider_base_model_id: str,
    train_split_name: str,
    system_message: str,
    parameters: dict[str, str | int | float | bool] = {},
    name: str | None = None,
    description: str | None = None,
    validation_split_name: str | None = None,
) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:

Create and start a fine-tune. Validates the provider and base model against built_in_models, checks that the dataset has an id and that the train split (and any validation split) exists, validates the parameters, then builds a Finetune datamodel under the dataset's parent task, launches the job via the subclass's _start, and saves the record to file. Returns the adapter instance and the saved datamodel. A memorable name is generated if none is provided.
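
A sketch of calling it with the ExampleFinetuneAdapter above. The provider and model ids are placeholders: check_valid_provider_model will raise unless that exact pair appears in built_in_models, and dataset_split is assumed to be an existing DatasetSplit with a "train" split.

# From within async code, e.g. asyncio.run(main()):
adapter, finetune = await ExampleFinetuneAdapter.create_and_start(
    dataset=dataset_split,  # an existing kiln_ai.datamodel.DatasetSplit
    provider_id="openai",  # placeholder; must match built_in_models
    provider_base_model_id="gpt-4o-mini-2024-07-18",  # placeholder
    train_split_name="train",
    system_message="You are a helpful assistant.",
    parameters={"epochs": 3},  # required by the example adapter above
)
print(finetune.name)  # auto-generated memorable name if none was passed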

@abstractmethod
async def status(self) -> FineTuneStatus:

Get the status of the fine-tune.
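
Continuing the create_and_start sketch above, polling is just:

status = await adapter.status()
print(status.status)
if status.message:
    print(status.message)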

@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:

Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
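
The base implementation returns an empty list; subclasses override it to advertise their tunable knobs. Callers can use it to build a settings UI or CLI prompt, for example with the illustrative adapter above:

for p in ExampleFinetuneAdapter.available_parameters():
    required = "" if p.optional else " (required)"
    print(f"{p.name}: {p.type}{required} - {p.description}")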

@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:

Validate the parameters for this fine-tune. Every non-optional parameter declared by available_parameters must be present; each supplied value must match its declared type, with strict checks for numeric types (an int is not accepted where a float is declared, and vice versa); and any key not declared by available_parameters raises ValueError.
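
A sketch of the behavior using the ExampleFinetuneAdapter above, which declares a required int "epochs" and an optional float "learning_rate_multiplier":

ExampleFinetuneAdapter.validate_parameters({"epochs": 3})  # ok

try:
    ExampleFinetuneAdapter.validate_parameters({})
except ValueError as e:
    print(e)  # Parameter epochs is required

try:
    ExampleFinetuneAdapter.validate_parameters(
        {"epochs": 3, "learning_rate_multiplier": 1}  # int where float declared
    )
except ValueError as e:
    print(e)  # Parameter learning_rate_multiplier must be a float, got <class 'int'>

One quirk worth knowing: because bool is a subclass of int in Python, isinstance(True, int) is True, so a boolean value will pass an "int" parameter's check as the code is written.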

@classmethod
def check_valid_provider_model(cls, provider_id: str, provider_base_model_id: str) -> None:

Check if the provider and base model are valid: some entry in built_in_models must list a provider whose name matches provider_id and whose provider_finetune_id matches provider_base_model_id. Raises ValueError otherwise.
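
A sketch; the ids shown are placeholders and only pass if that exact provider/finetune-id pair exists in built_in_models:

try:
    BaseFinetuneAdapter.check_valid_provider_model(
        "openai", "gpt-4o-mini-2024-07-18"  # placeholder ids
    )
    print("provider/model pair is available for fine-tuning")
except ValueError as e:
    print(e)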