kiln_ai.adapters.remote_config

  1import argparse
  2import asyncio
  3import json
  4import logging
  5import os
  6import threading
  7from dataclasses import dataclass
  8from pathlib import Path
  9from typing import Any, List
 10
 11import httpx
 12from pydantic import ValidationError
 13
 14from kiln_ai.adapters.ml_embedding_model_list import (
 15    KilnEmbeddingModel,
 16    KilnEmbeddingModelProvider,
 17    built_in_embedding_models,
 18)
 19from kiln_ai.adapters.reranker_list import (
 20    KilnRerankerModel,
 21    KilnRerankerModelProvider,
 22    built_in_rerankers,
 23)
 24from kiln_ai.datamodel.datamodel_enums import KilnMimeType
 25
 26from .ml_model_list import KilnModel, KilnModelProvider, built_in_models
 27
 28logger = logging.getLogger(__name__)
 29
# Loads the JSON config hosted on GitHub Pages.
# Public config build logs are available here: https://github.com/Kiln-AI/remote_config/actions/workflows/publish_remote_config.yml
# The content is hosted on GitHub Pages: https://kiln-ai.github.io/remote_config/kiln_config_v1.json
# Why v2: Kiln v0.18 was the first release with remote config, but it had bugs. The v1 URL is no longer published (clients requesting it fall back to their local list), and v2 is used instead.
 34REMOTE_MODEL_LIST_URL = "https://remote-config.getkiln.ai/kiln_config_v2.json"
 35
 36refresh_lock = threading.Lock()
 37
 38
 39def should_skip_remote_model_list() -> bool:
 40    """Check if remote model list fetching should be skipped."""
 41    return os.environ.get("KILN_SKIP_REMOTE_MODEL_LIST") == "true"
 42
 43
 44@dataclass
 45class KilnRemoteConfig:
 46    model_list: List[KilnModel]
 47    embedding_model_list: List[KilnEmbeddingModel]
 48    reranker_model_list: List[KilnRerankerModel]
 49
 50
 51def serialize_config(
 52    models: List[KilnModel],
 53    embedding_models: List[KilnEmbeddingModel],
 54    reranker_models: List[KilnRerankerModel],
 55    path: str | Path,
 56) -> None:
 57    data = {
 58        "model_list": [m.model_dump(mode="json") for m in models],
 59        "embedding_model_list": [m.model_dump(mode="json") for m in embedding_models],
 60        "reranker_model_list": [m.model_dump(mode="json") for m in reranker_models],
 61    }
 62    Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))
 63
 64
 65def deserialize_config_at_path(
 66    path: str | Path,
 67) -> KilnRemoteConfig:
 68    raw = json.loads(Path(path).read_text())
 69    return deserialize_config_data(raw)
 70
 71
 72def deserialize_config_data(
 73    config_data: Any,
 74) -> KilnRemoteConfig:
 75    if not isinstance(config_data, dict):
 76        raise ValueError(f"Remote config expected dict, got {type(config_data)}")
 77
 78    model_list = config_data.get("model_list", None)
 79    if not isinstance(model_list, list):
 80        raise ValueError(
 81            f"Remote config expected list of models, got {type(model_list)}"
 82        )
 83
 84    embedding_model_data = config_data.get("embedding_model_list", [])
 85    if not isinstance(embedding_model_data, list):
 86        raise ValueError(
 87            f"Remote config expected list of embedding models, got {type(embedding_model_data)}"
 88        )
 89
 90    reranker_model_data = config_data.get("reranker_model_list", [])
 91    if not isinstance(reranker_model_data, list):
 92        raise ValueError(
 93            f"Remote config expected list of reranker models, got {type(reranker_model_data)}"
 94        )
 95
 96    # We must be careful here, because some of the JSON data may be generated from a forward
 97    # version of the code that has newer fields / versions of the fields, that may cause
 98    # the current client this code is running on to fail to validate the item into a KilnModel.
 99    models = []
100    for model_data in model_list:
101        # We skip any model that fails validation - the models that the client can support
102        # will be pulled from the remote config, but the user will need to update their
103        # client to the latest version to see the newer models that break backwards compatibility.
104        try:
105            providers_list = model_data.get("providers", [])
106
107            providers = []
108            for provider_data in providers_list:
109                try:
110                    # we filter out the mime types that we don't support
111                    mime_types = provider_data.get("multimodal_mime_types")
112                    if mime_types is not None:
113                        provider_data["multimodal_mime_types"] = [
114                            mime_type
115                            for mime_type in mime_types
116                            if mime_type in list(KilnMimeType)
117                        ]
118                    provider = KilnModelProvider.model_validate(provider_data)
119                    providers.append(provider)
120                except ValidationError as e:
121                    logger.warning(
122                        "Failed to validate a model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
123                        provider_data,
124                        e,
125                    )
126
127            # this ensures the model deserialization won't fail because of a bad provider
128            model_data["providers"] = []
129
130            # now we validate the model without its providers
131            model = KilnModel.model_validate(model_data)
132
133            # and we attach back the providers that passed our validation
134            model.providers = providers
135            models.append(model)
136        except ValidationError as e:
137            logger.warning(
138                "Failed to validate a model from remote config. Upgrade Kiln to use this model. Details %s: %s",
139                model_data,
140                e,
141            )
142
143    embedding_models = []
144    for embedding_model_data in embedding_model_data:
145        try:
146            provider_list = embedding_model_data.get("providers", [])
147            providers = []
148            for provider_data in provider_list:
149                try:
150                    provider = KilnEmbeddingModelProvider.model_validate(provider_data)
151                    providers.append(provider)
152                except ValidationError as e:
153                    logger.warning(
154                        "Failed to validate an embedding model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
155                        provider_data,
156                        e,
157                    )
158
159            embedding_model_data["providers"] = []
160            embedding_model = KilnEmbeddingModel.model_validate(embedding_model_data)
161            embedding_model.providers = providers
162            embedding_models.append(embedding_model)
163        except ValidationError as e:
164            logger.warning(
165                "Failed to validate an embedding model from remote config. Upgrade Kiln to use this model. Details %s: %s",
166                embedding_model_data,
167                e,
168            )
169
170    reranker_models = []
171    for reranker_model_data in reranker_model_data:
172        try:
173            provider_list = reranker_model_data.get("providers", [])
174            providers = []
175            for provider_data in provider_list:
176                try:
177                    provider = KilnRerankerModelProvider.model_validate(provider_data)
178                    providers.append(provider)
179                except ValidationError as e:
180                    logger.warning(
181                        "Failed to validate a reranker model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
182                        provider_data,
183                        e,
184                    )
185
186            reranker_model_data["providers"] = []
187            reranker_model = KilnRerankerModel.model_validate(reranker_model_data)
188            reranker_model.providers = providers
189            reranker_models.append(reranker_model)
190        except ValidationError as e:
191            logger.warning(
192                "Failed to validate a reranker model from remote config. Upgrade Kiln to use this model. Details %s: %s",
193                reranker_model_data,
194                e,
195            )
196
197    return KilnRemoteConfig(
198        model_list=models,
199        embedding_model_list=embedding_models,
200        reranker_model_list=reranker_models,
201    )
202
203
async def load_from_url(url: str) -> KilnRemoteConfig:
    """Download a config document from ``url`` and validate it.

    Args:
        url: Address of the JSON config document.

    Returns:
        The config produced by ``deserialize_config_data`` for the response.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status
            (raised by ``raise_for_status``).
        ValueError: If the document is rejected by
            ``deserialize_config_data``.
    """
    async with httpx.AsyncClient() as http_client:
        # The request is given a 10 second timeout.
        reply = await http_client.get(url, timeout=10.0)
        reply.raise_for_status()
        payload = reply.json()
    return deserialize_config_data(payload)
210
211
def dump_builtin_config(path: str | Path) -> None:
    """Serialize the current built-in model lists to a JSON file.

    Args:
        path: Destination file passed through to ``serialize_config``.
    """
    serialize_config(
        path=path,
        models=built_in_models,
        embedding_models=built_in_embedding_models,
        reranker_models=built_in_rerankers,
    )
219
220
async def refresh_model_list(url: str = REMOTE_MODEL_LIST_URL) -> None:
    """Fetch the remote config and replace the built-in model lists in place.

    The download and validation finish before the lock is taken. The three
    built-in lists are then overwritten by slice assignment while holding
    ``refresh_lock``, so the list objects stay the same and only their
    contents change.

    NOTE(review): only the writes in this function are guarded by
    ``refresh_lock``; whether readers of the lists take it is not visible
    here. Treat this as safe for asyncio use, and add your own locking if
    you call it from multiple threads.

    Args:
        url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.
    """
    remote_config = await load_from_url(url)
    replacements = (
        (built_in_models, remote_config.model_list),
        (built_in_embedding_models, remote_config.embedding_model_list),
        (built_in_rerankers, remote_config.reranker_model_list),
    )
    with refresh_lock:
        for target, fresh in replacements:
            target[:] = fresh
233
234
def refresh_model_list_background(
    url: str = REMOTE_MODEL_LIST_URL,
) -> threading.Thread:
    """Start a daemon thread that refreshes the model list.

    Any exception raised by the refresh is logged as a warning and does not
    propagate.

    Args:
        url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.

    Returns:
        The already-started thread, so callers can join() it if needed.
    """

    def _worker() -> None:
        # asyncio.run gives this thread its own event loop for the refresh.
        try:
            asyncio.run(refresh_model_list(url))
        except Exception as error:
            logger.warning("Failed to fetch remote model list from %s: %s", url, error)

    worker = threading.Thread(target=_worker, daemon=True)
    worker.start()
    return worker
256
257
def main() -> None:
    """Command-line entry point: dump the built-in config to a file.

    Takes one positional argument, ``path``, and writes the built-in model
    lists there via ``dump_builtin_config``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("path", help="output path")
    output_path = cli.parse_args().path
    dump_builtin_config(output_path)


if __name__ == "__main__":
    main()
logger = <Logger kiln_ai.adapters.remote_config (WARNING)>
REMOTE_MODEL_LIST_URL = 'https://remote-config.getkiln.ai/kiln_config_v2.json'
refresh_lock = <unlocked _thread.lock object>
def should_skip_remote_model_list() -> bool:
40def should_skip_remote_model_list() -> bool:
41    """Check if remote model list fetching should be skipped."""
42    return os.environ.get("KILN_SKIP_REMOTE_MODEL_LIST") == "true"

Check if remote model list fetching should be skipped.

@dataclass
class KilnRemoteConfig:
45@dataclass
46class KilnRemoteConfig:
47    model_list: List[KilnModel]
48    embedding_model_list: List[KilnEmbeddingModel]
49    reranker_model_list: List[KilnRerankerModel]
KilnRemoteConfig( model_list: List[kiln_ai.adapters.ml_model_list.KilnModel], embedding_model_list: List[kiln_ai.adapters.ml_embedding_model_list.KilnEmbeddingModel], reranker_model_list: List[kiln_ai.adapters.reranker_list.KilnRerankerModel])
reranker_model_list: List[kiln_ai.adapters.reranker_list.KilnRerankerModel]
def serialize_config( models: List[kiln_ai.adapters.ml_model_list.KilnModel], embedding_models: List[kiln_ai.adapters.ml_embedding_model_list.KilnEmbeddingModel], reranker_models: List[kiln_ai.adapters.reranker_list.KilnRerankerModel], path: str | pathlib._local.Path) -> None:
52def serialize_config(
53    models: List[KilnModel],
54    embedding_models: List[KilnEmbeddingModel],
55    reranker_models: List[KilnRerankerModel],
56    path: str | Path,
57) -> None:
58    data = {
59        "model_list": [m.model_dump(mode="json") for m in models],
60        "embedding_model_list": [m.model_dump(mode="json") for m in embedding_models],
61        "reranker_model_list": [m.model_dump(mode="json") for m in reranker_models],
62    }
63    Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))
def deserialize_config_at_path( path: str | pathlib._local.Path) -> KilnRemoteConfig:
66def deserialize_config_at_path(
67    path: str | Path,
68) -> KilnRemoteConfig:
69    raw = json.loads(Path(path).read_text())
70    return deserialize_config_data(raw)
def deserialize_config_data(config_data: Any) -> KilnRemoteConfig:
 73def deserialize_config_data(
 74    config_data: Any,
 75) -> KilnRemoteConfig:
 76    if not isinstance(config_data, dict):
 77        raise ValueError(f"Remote config expected dict, got {type(config_data)}")
 78
 79    model_list = config_data.get("model_list", None)
 80    if not isinstance(model_list, list):
 81        raise ValueError(
 82            f"Remote config expected list of models, got {type(model_list)}"
 83        )
 84
 85    embedding_model_data = config_data.get("embedding_model_list", [])
 86    if not isinstance(embedding_model_data, list):
 87        raise ValueError(
 88            f"Remote config expected list of embedding models, got {type(embedding_model_data)}"
 89        )
 90
 91    reranker_model_data = config_data.get("reranker_model_list", [])
 92    if not isinstance(reranker_model_data, list):
 93        raise ValueError(
 94            f"Remote config expected list of reranker models, got {type(reranker_model_data)}"
 95        )
 96
 97    # We must be careful here, because some of the JSON data may be generated from a forward
 98    # version of the code that has newer fields / versions of the fields, that may cause
 99    # the current client this code is running on to fail to validate the item into a KilnModel.
100    models = []
101    for model_data in model_list:
102        # We skip any model that fails validation - the models that the client can support
103        # will be pulled from the remote config, but the user will need to update their
104        # client to the latest version to see the newer models that break backwards compatibility.
105        try:
106            providers_list = model_data.get("providers", [])
107
108            providers = []
109            for provider_data in providers_list:
110                try:
111                    # we filter out the mime types that we don't support
112                    mime_types = provider_data.get("multimodal_mime_types")
113                    if mime_types is not None:
114                        provider_data["multimodal_mime_types"] = [
115                            mime_type
116                            for mime_type in mime_types
117                            if mime_type in list(KilnMimeType)
118                        ]
119                    provider = KilnModelProvider.model_validate(provider_data)
120                    providers.append(provider)
121                except ValidationError as e:
122                    logger.warning(
123                        "Failed to validate a model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
124                        provider_data,
125                        e,
126                    )
127
128            # this ensures the model deserialization won't fail because of a bad provider
129            model_data["providers"] = []
130
131            # now we validate the model without its providers
132            model = KilnModel.model_validate(model_data)
133
134            # and we attach back the providers that passed our validation
135            model.providers = providers
136            models.append(model)
137        except ValidationError as e:
138            logger.warning(
139                "Failed to validate a model from remote config. Upgrade Kiln to use this model. Details %s: %s",
140                model_data,
141                e,
142            )
143
144    embedding_models = []
145    for embedding_model_data in embedding_model_data:
146        try:
147            provider_list = embedding_model_data.get("providers", [])
148            providers = []
149            for provider_data in provider_list:
150                try:
151                    provider = KilnEmbeddingModelProvider.model_validate(provider_data)
152                    providers.append(provider)
153                except ValidationError as e:
154                    logger.warning(
155                        "Failed to validate an embedding model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
156                        provider_data,
157                        e,
158                    )
159
160            embedding_model_data["providers"] = []
161            embedding_model = KilnEmbeddingModel.model_validate(embedding_model_data)
162            embedding_model.providers = providers
163            embedding_models.append(embedding_model)
164        except ValidationError as e:
165            logger.warning(
166                "Failed to validate an embedding model from remote config. Upgrade Kiln to use this model. Details %s: %s",
167                embedding_model_data,
168                e,
169            )
170
171    reranker_models = []
172    for reranker_model_data in reranker_model_data:
173        try:
174            provider_list = reranker_model_data.get("providers", [])
175            providers = []
176            for provider_data in provider_list:
177                try:
178                    provider = KilnRerankerModelProvider.model_validate(provider_data)
179                    providers.append(provider)
180                except ValidationError as e:
181                    logger.warning(
182                        "Failed to validate a reranker model provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
183                        provider_data,
184                        e,
185                    )
186
187            reranker_model_data["providers"] = []
188            reranker_model = KilnRerankerModel.model_validate(reranker_model_data)
189            reranker_model.providers = providers
190            reranker_models.append(reranker_model)
191        except ValidationError as e:
192            logger.warning(
193                "Failed to validate a reranker model from remote config. Upgrade Kiln to use this model. Details %s: %s",
194                reranker_model_data,
195                e,
196            )
197
198    return KilnRemoteConfig(
199        model_list=models,
200        embedding_model_list=embedding_models,
201        reranker_model_list=reranker_models,
202    )
async def load_from_url(url: str) -> KilnRemoteConfig:
205async def load_from_url(url: str) -> KilnRemoteConfig:
206    async with httpx.AsyncClient() as client:
207        response = await client.get(url, timeout=10.0)
208        response.raise_for_status()
209        data = response.json()
210        return deserialize_config_data(data)
def dump_builtin_config(path: str | pathlib._local.Path) -> None:
213def dump_builtin_config(path: str | Path) -> None:
214    serialize_config(
215        models=built_in_models,
216        embedding_models=built_in_embedding_models,
217        reranker_models=built_in_rerankers,
218        path=path,
219    )
async def refresh_model_list( url: str = 'https://remote-config.getkiln.ai/kiln_config_v2.json') -> None:
222async def refresh_model_list(url: str = REMOTE_MODEL_LIST_URL) -> None:
223    """Refresh the model list from a URL. This is not thread safe, only asyncio is safe.
224    If you call this from threads, make sure to wrap in an actual lock.
225
226    Args:
227        url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.
228    """
229    models = await load_from_url(url)
230    with refresh_lock:
231        built_in_models[:] = models.model_list
232        built_in_embedding_models[:] = models.embedding_model_list
233        built_in_rerankers[:] = models.reranker_model_list

Refresh the model list from a URL. This function is not thread-safe; it is only safe to call from asyncio code. If you call it from multiple threads, guard the call with a real lock.

Args: url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.

def refresh_model_list_background( url: str = 'https://remote-config.getkiln.ai/kiln_config_v2.json') -> threading.Thread:
236def refresh_model_list_background(
237    url: str = REMOTE_MODEL_LIST_URL,
238) -> threading.Thread:
239    """Refresh the model list in a background thread.
240
241    Args:
242        url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.
243
244    Returns:
245        The background thread, so callers can join() it if needed.
246    """
247
248    def run_async_in_thread() -> None:
249        try:
250            asyncio.run(refresh_model_list(url))
251        except Exception as exc:
252            logger.warning("Failed to fetch remote model list from %s: %s", url, exc)
253
254    thread = threading.Thread(target=run_async_in_thread, daemon=True)
255    thread.start()
256    return thread

Refresh the model list in a background thread.

Args: url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.

Returns: The background thread, so callers can join() it if needed.

def main() -> None:
259def main() -> None:
260    parser = argparse.ArgumentParser()
261    parser.add_argument("path", help="output path")
262    args = parser.parse_args()
263    dump_builtin_config(args.path)