kiln_ai.adapters.remote_config
import argparse
import asyncio
import json
import logging
import os
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List

import httpx
from pydantic import ValidationError

from kiln_ai.adapters.ml_embedding_model_list import (
    KilnEmbeddingModel,
    KilnEmbeddingModelProvider,
    built_in_embedding_models,
)
from kiln_ai.adapters.reranker_list import (
    KilnRerankerModel,
    KilnRerankerModelProvider,
    built_in_rerankers,
)
from kiln_ai.datamodel.datamodel_enums import KilnMimeType

from .ml_model_list import KilnModel, KilnModelProvider, built_in_models

logger = logging.getLogger(__name__)

# Loads github pages hosted JSON config.
# You can see public config build logs here: https://github.com/Kiln-AI/remote_config/actions/workflows/publish_remote_config.yml
# Content is hosted on Github Pages: https://kiln-ai.github.io/remote_config/kiln_config_v1.json
# V2 explained: Kiln v0.18 was the first release with remote config, but had bugs. We no longer publish v1 URL (client falls back to local) and instead use v2.
REMOTE_MODEL_LIST_URL = "https://remote-config.getkiln.ai/kiln_config_v2.json"

# Held while the built-in lists are replaced in refresh_model_list().
refresh_lock = threading.Lock()


def should_skip_remote_model_list() -> bool:
    """Check if remote model list fetching should be skipped."""
    return os.environ.get("KILN_SKIP_REMOTE_MODEL_LIST") == "true"


@dataclass
class KilnRemoteConfig:
    """The three model lists carried by one remote config document."""

    model_list: List[KilnModel]
    embedding_model_list: List[KilnEmbeddingModel]
    reranker_model_list: List[KilnRerankerModel]


def serialize_config(
    models: List[KilnModel],
    embedding_models: List[KilnEmbeddingModel],
    reranker_models: List[KilnRerankerModel],
    path: str | Path,
) -> None:
    """Write the given model lists to ``path`` as indented, key-sorted JSON.

    Args:
        models: Models stored under the ``model_list`` key.
        embedding_models: Models stored under the ``embedding_model_list`` key.
        reranker_models: Models stored under the ``reranker_model_list`` key.
        path: Destination file; it is overwritten.
    """
    data = {
        "model_list": [m.model_dump(mode="json") for m in models],
        "embedding_model_list": [m.model_dump(mode="json") for m in embedding_models],
        "reranker_model_list": [m.model_dump(mode="json") for m in reranker_models],
    }
    Path(path).write_text(json.dumps(data, indent=2, sort_keys=True))


def deserialize_config_at_path(
    path: str | Path,
) -> KilnRemoteConfig:
    """Read the JSON file at ``path`` and deserialize it with deserialize_config_data()."""
    raw = json.loads(Path(path).read_text())
    return deserialize_config_data(raw)


def _filter_supported_mime_types(provider_data: Any) -> Any:
    """Return provider data with unsupported ``multimodal_mime_types`` entries removed.

    The input is never mutated; a shallow copy is returned when filtering applies.
    Anything that is not a dict with a list of mime types is returned unchanged so
    that model validation can accept or reject it.
    """
    if not isinstance(provider_data, dict):
        return provider_data
    mime_types = provider_data.get("multimodal_mime_types")
    if not isinstance(mime_types, list):
        return provider_data
    # Built once per provider rather than once per mime type.
    supported = list(KilnMimeType)
    return {
        **provider_data,
        "multimodal_mime_types": [m for m in mime_types if m in supported],
    }


def _validate_providers(
    providers_data: Any,
    provider_cls: Any,
    label: str,
    prepare: Any = None,
) -> List[Any]:
    """Validate each provider entry with ``provider_cls``, skipping the ones that fail.

    ``label`` is the phrase used in warnings (for example ``"a model"``) and
    ``prepare`` is an optional callable applied to each entry before validation.
    """
    if providers_data is None:
        return []
    if not isinstance(providers_data, list):
        logger.warning(
            "Failed to validate %s provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
            label,
            providers_data,
            f"expected list of providers, got {type(providers_data)}",
        )
        return []

    validated = []
    for provider_data in providers_data:
        try:
            candidate = prepare(provider_data) if prepare is not None else provider_data
            validated.append(provider_cls.model_validate(candidate))
        except ValidationError as e:
            logger.warning(
                "Failed to validate %s provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
                label,
                provider_data,
                e,
            )
    return validated


def _validate_models(
    entries: List[Any],
    model_cls: Any,
    provider_cls: Any,
    label: str,
    prepare_provider: Any = None,
) -> List[Any]:
    """Validate each entry with ``model_cls``, skipping entries the client cannot load."""
    validated = []
    for entry in entries:
        # A malformed entry must only cost us that entry, never the whole config.
        if not isinstance(entry, dict):
            logger.warning(
                "Failed to validate %s from remote config. Upgrade Kiln to use this model. Details %s: %s",
                label,
                entry,
                f"expected dict, got {type(entry)}",
            )
            continue

        providers = _validate_providers(
            entry.get("providers"), provider_cls, label, prepare_provider
        )
        try:
            # Validate the model without its providers so a bad provider cannot
            # fail the model, then attach the providers that passed validation.
            # The entry is copied so the caller's data is left untouched.
            model = model_cls.model_validate({**entry, "providers": []})
            model.providers = providers
        except ValidationError as e:
            logger.warning(
                "Failed to validate %s from remote config. Upgrade Kiln to use this model. Details %s: %s",
                label,
                entry,
                e,
            )
            continue
        validated.append(model)
    return validated


def deserialize_config_data(
    config_data: Any,
) -> KilnRemoteConfig:
    """Build a KilnRemoteConfig from already-parsed JSON data.

    The JSON may come from a newer version of the code with fields this client
    cannot validate. Models and providers that fail validation are skipped with
    a warning instead of failing the whole load; the user needs to upgrade to see
    them. The input data is not modified.

    Args:
        config_data: Parsed JSON; must be a dict with a ``model_list`` list and
            optional ``embedding_model_list`` / ``reranker_model_list`` lists.

    Returns:
        The models, embedding models and reranker models that validated.

    Raises:
        ValueError: If ``config_data`` is not a dict, or any of the three lists
            is present but not a list (``model_list`` is required).
    """
    if not isinstance(config_data, dict):
        raise ValueError(f"Remote config expected dict, got {type(config_data)}")

    model_list = config_data.get("model_list", None)
    if not isinstance(model_list, list):
        raise ValueError(
            f"Remote config expected list of models, got {type(model_list)}"
        )

    embedding_model_data = config_data.get("embedding_model_list", [])
    if not isinstance(embedding_model_data, list):
        raise ValueError(
            f"Remote config expected list of embedding models, got {type(embedding_model_data)}"
        )

    reranker_model_data = config_data.get("reranker_model_list", [])
    if not isinstance(reranker_model_data, list):
        raise ValueError(
            f"Remote config expected list of reranker models, got {type(reranker_model_data)}"
        )

    return KilnRemoteConfig(
        model_list=_validate_models(
            model_list,
            KilnModel,
            KilnModelProvider,
            "a model",
            # we filter out the mime types that we don't support
            _filter_supported_mime_types,
        ),
        embedding_model_list=_validate_models(
            embedding_model_data,
            KilnEmbeddingModel,
            KilnEmbeddingModelProvider,
            "an embedding model",
        ),
        reranker_model_list=_validate_models(
            reranker_model_data,
            KilnRerankerModel,
            KilnRerankerModelProvider,
            "a reranker model",
        ),
    )


async def load_from_url(url: str) -> KilnRemoteConfig:
    """Fetch JSON from ``url`` (10 second timeout) and deserialize it.

    Raises whatever httpx raises for transport errors or non-success status
    codes, and ValueError from deserialize_config_data() for a malformed document.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(url, timeout=10.0)
        response.raise_for_status()
        data = response.json()
        return deserialize_config_data(data)


def dump_builtin_config(path: str | Path) -> None:
    """Serialize the current built-in model lists to ``path``."""
    serialize_config(
        models=built_in_models,
        embedding_models=built_in_embedding_models,
        reranker_models=built_in_rerankers,
        path=path,
    )


async def refresh_model_list(url: str = REMOTE_MODEL_LIST_URL) -> None:
    """Refresh the built-in model lists from a URL.

    The three lists are replaced in place while holding ``refresh_lock``, so
    concurrent refreshes do not interleave. Code that reads the lists without
    taking the lock may still observe a partially applied refresh.

    Args:
        url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.
    """
    models = await load_from_url(url)
    with refresh_lock:
        # Slice assignment mutates the existing list objects instead of rebinding.
        built_in_models[:] = models.model_list
        built_in_embedding_models[:] = models.embedding_model_list
        built_in_rerankers[:] = models.reranker_model_list


def refresh_model_list_background(
    url: str = REMOTE_MODEL_LIST_URL,
) -> threading.Thread:
    """Refresh the model list in a background thread.

    Args:
        url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.

    Returns:
        The background thread, so callers can join() it if needed.
    """

    def run_async_in_thread() -> None:
        # Best effort: a failed refresh is logged and the current lists are kept.
        try:
            asyncio.run(refresh_model_list(url))
        except Exception as exc:
            logger.warning("Failed to fetch remote model list from %s: %s", url, exc)

    thread = threading.Thread(target=run_async_in_thread, daemon=True)
    thread.start()
    return thread


def main() -> None:
    """Command-line entry point: dump the built-in config to the given output path."""
    parser = argparse.ArgumentParser()
    parser.add_argument("path", help="output path")
    args = parser.parse_args()
    dump_builtin_config(args.path)


if __name__ == "__main__":
    main()
logger =
<Logger kiln_ai.adapters.remote_config (WARNING)>
REMOTE_MODEL_LIST_URL =
'https://remote-config.getkiln.ai/kiln_config_v2.json'
refresh_lock =
<unlocked _thread.lock object>
def
should_skip_remote_model_list() -> bool:
def should_skip_remote_model_list() -> bool:
    """Report whether fetching the remote model list is disabled.

    Returns:
        True only when the KILN_SKIP_REMOTE_MODEL_LIST environment variable is
        exactly the string "true"; False for any other value or when unset.
    """
    skip_flag = os.getenv("KILN_SKIP_REMOTE_MODEL_LIST")
    return skip_flag == "true"
Check if remote model list fetching should be skipped.
@dataclass
class
KilnRemoteConfig:
@dataclass
class KilnRemoteConfig:
    """Container for the three model lists carried by one remote config document."""

    # Entries taken from the "model_list" key of the config.
    model_list: List[KilnModel]
    # Entries taken from the "embedding_model_list" key of the config.
    embedding_model_list: List[KilnEmbeddingModel]
    # Entries taken from the "reranker_model_list" key of the config.
    reranker_model_list: List[KilnRerankerModel]
KilnRemoteConfig( model_list: List[kiln_ai.adapters.ml_model_list.KilnModel], embedding_model_list: List[kiln_ai.adapters.ml_embedding_model_list.KilnEmbeddingModel], reranker_model_list: List[kiln_ai.adapters.reranker_list.KilnRerankerModel])
model_list: List[kiln_ai.adapters.ml_model_list.KilnModel]
embedding_model_list: List[kiln_ai.adapters.ml_embedding_model_list.KilnEmbeddingModel]
def
serialize_config( models: List[kiln_ai.adapters.ml_model_list.KilnModel], embedding_models: List[kiln_ai.adapters.ml_embedding_model_list.KilnEmbeddingModel], reranker_models: List[kiln_ai.adapters.reranker_list.KilnRerankerModel], path: str | pathlib._local.Path) -> None:
def serialize_config(
    models: List[KilnModel],
    embedding_models: List[KilnEmbeddingModel],
    reranker_models: List[KilnRerankerModel],
    path: str | Path,
) -> None:
    """Write the given model lists to ``path`` as indented, key-sorted JSON.

    Args:
        models: Models stored under the "model_list" key.
        embedding_models: Models stored under the "embedding_model_list" key.
        reranker_models: Models stored under the "reranker_model_list" key.
        path: Destination file; it is overwritten.
    """

    def dump_all(items):
        # Each item is dumped in JSON mode so the result is json-serializable.
        return [item.model_dump(mode="json") for item in items]

    payload = {
        "model_list": dump_all(models),
        "embedding_model_list": dump_all(embedding_models),
        "reranker_model_list": dump_all(reranker_models),
    }
    text = json.dumps(payload, indent=2, sort_keys=True)
    Path(path).write_text(text)
def _filter_supported_mime_types(provider_data: Any) -> Any:
    """Return provider data with unsupported ``multimodal_mime_types`` entries removed.

    The input is never mutated; a shallow copy is returned when filtering applies.
    Anything that is not a dict with a list of mime types is returned unchanged so
    that model validation can accept or reject it.
    """
    if not isinstance(provider_data, dict):
        return provider_data
    mime_types = provider_data.get("multimodal_mime_types")
    if not isinstance(mime_types, list):
        return provider_data
    # Built once per provider rather than once per mime type.
    supported = list(KilnMimeType)
    return {
        **provider_data,
        "multimodal_mime_types": [m for m in mime_types if m in supported],
    }


def _validate_providers(
    providers_data: Any,
    provider_cls: Any,
    label: str,
    prepare: Any = None,
) -> List[Any]:
    """Validate each provider entry with ``provider_cls``, skipping the ones that fail.

    ``label`` is the phrase used in warnings (for example ``"a model"``) and
    ``prepare`` is an optional callable applied to each entry before validation.
    """
    if providers_data is None:
        return []
    if not isinstance(providers_data, list):
        logger.warning(
            "Failed to validate %s provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
            label,
            providers_data,
            f"expected list of providers, got {type(providers_data)}",
        )
        return []

    validated = []
    for provider_data in providers_data:
        try:
            candidate = prepare(provider_data) if prepare is not None else provider_data
            validated.append(provider_cls.model_validate(candidate))
        except ValidationError as e:
            logger.warning(
                "Failed to validate %s provider from remote config. Upgrade Kiln to use this model. Details %s: %s",
                label,
                provider_data,
                e,
            )
    return validated


def _validate_models(
    entries: List[Any],
    model_cls: Any,
    provider_cls: Any,
    label: str,
    prepare_provider: Any = None,
) -> List[Any]:
    """Validate each entry with ``model_cls``, skipping entries the client cannot load."""
    validated = []
    for entry in entries:
        # A malformed entry must only cost us that entry, never the whole config.
        if not isinstance(entry, dict):
            logger.warning(
                "Failed to validate %s from remote config. Upgrade Kiln to use this model. Details %s: %s",
                label,
                entry,
                f"expected dict, got {type(entry)}",
            )
            continue

        providers = _validate_providers(
            entry.get("providers"), provider_cls, label, prepare_provider
        )
        try:
            # Validate the model without its providers so a bad provider cannot
            # fail the model, then attach the providers that passed validation.
            # The entry is copied so the caller's data is left untouched.
            model = model_cls.model_validate({**entry, "providers": []})
            model.providers = providers
        except ValidationError as e:
            logger.warning(
                "Failed to validate %s from remote config. Upgrade Kiln to use this model. Details %s: %s",
                label,
                entry,
                e,
            )
            continue
        validated.append(model)
    return validated


def deserialize_config_data(
    config_data: Any,
) -> KilnRemoteConfig:
    """Build a KilnRemoteConfig from already-parsed JSON data.

    The JSON may come from a newer version of the code with fields this client
    cannot validate. Models and providers that fail validation are skipped with
    a warning instead of failing the whole load; the user needs to upgrade to see
    them. The input data is not modified.

    Args:
        config_data: Parsed JSON; must be a dict with a ``model_list`` list and
            optional ``embedding_model_list`` / ``reranker_model_list`` lists.

    Returns:
        The models, embedding models and reranker models that validated.

    Raises:
        ValueError: If ``config_data`` is not a dict, or any of the three lists
            is present but not a list (``model_list`` is required).
    """
    if not isinstance(config_data, dict):
        raise ValueError(f"Remote config expected dict, got {type(config_data)}")

    model_list = config_data.get("model_list", None)
    if not isinstance(model_list, list):
        raise ValueError(
            f"Remote config expected list of models, got {type(model_list)}"
        )

    embedding_model_data = config_data.get("embedding_model_list", [])
    if not isinstance(embedding_model_data, list):
        raise ValueError(
            f"Remote config expected list of embedding models, got {type(embedding_model_data)}"
        )

    reranker_model_data = config_data.get("reranker_model_list", [])
    if not isinstance(reranker_model_data, list):
        raise ValueError(
            f"Remote config expected list of reranker models, got {type(reranker_model_data)}"
        )

    return KilnRemoteConfig(
        model_list=_validate_models(
            model_list,
            KilnModel,
            KilnModelProvider,
            "a model",
            # we filter out the mime types that we don't support
            _filter_supported_mime_types,
        ),
        embedding_model_list=_validate_models(
            embedding_model_data,
            KilnEmbeddingModel,
            KilnEmbeddingModelProvider,
            "an embedding model",
        ),
        reranker_model_list=_validate_models(
            reranker_model_data,
            KilnRerankerModel,
            KilnRerankerModelProvider,
            "a reranker model",
        ),
    )
def
dump_builtin_config(path: str | pathlib._local.Path) -> None:
async def
refresh_model_list( url: str = 'https://remote-config.getkiln.ai/kiln_config_v2.json') -> None:
async def refresh_model_list(url: str = REMOTE_MODEL_LIST_URL) -> None:
    """Refresh the built-in model lists from a URL.

    The config is fetched with load_from_url(); the three built-in lists are
    then replaced in place while holding ``refresh_lock``, so concurrent
    refreshes do not interleave. Code that reads the lists without taking the
    lock may still observe a partially applied refresh.

    Args:
        url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.
    """
    remote_config = await load_from_url(url)
    with refresh_lock:
        # Slice assignment mutates the existing list objects instead of rebinding.
        built_in_models[:] = remote_config.model_list
        built_in_embedding_models[:] = remote_config.embedding_model_list
        built_in_rerankers[:] = remote_config.reranker_model_list
Refresh the model list from a URL. The built-in lists are replaced in place while holding refresh_lock, so concurrent refreshes are serialized. Code that reads the lists without taking the lock may still observe a partially applied refresh.
Args: url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.
def
refresh_model_list_background( url: str = 'https://remote-config.getkiln.ai/kiln_config_v2.json') -> threading.Thread:
def refresh_model_list_background(
    url: str = REMOTE_MODEL_LIST_URL,
) -> threading.Thread:
    """Start a daemon thread that refreshes the model list and return it.

    Any exception raised by the refresh is logged as a warning inside the
    thread rather than propagated.

    Args:
        url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.

    Returns:
        The already-started background thread, so callers can join() it if needed.
    """

    def run_async_in_thread() -> None:
        # The worker thread gets its own event loop via asyncio.run().
        try:
            asyncio.run(refresh_model_list(url))
        except Exception as exc:
            logger.warning("Failed to fetch remote model list from %s: %s", url, exc)

    worker = threading.Thread(target=run_async_in_thread, daemon=True)
    worker.start()
    return worker
Refresh the model list in a background thread.
Args: url: The URL to fetch the model list from. Defaults to REMOTE_MODEL_LIST_URL.
Returns: The background thread, so callers can join() it if needed.
def
main() -> None: