kiln_ai.adapters.model_adapters.litellm_adapter
import logging
from typing import Any, Dict

import litellm
from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
from litellm.types.utils import Usage as LiteLlmUsage

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.task import run_config_from_run_config_properties
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

logger = logging.getLogger(__name__)


class LiteLlmAdapter(BaseAdapter):
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None

        # Create a RunConfig, adding the task to the RunConfigProperties
        run_config = run_config_from_run_config_properties(
            task=kiln_task,
            run_config_properties=config.run_config_properties,
        )

        super().__init__(
            run_config=run_config,
            config=base_adapter_config,
        )

    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input)

        prior_output = None
        prior_message = None
        response = None
        turns = 0
        while True:
            turns += 1
            if turns > 10:
                raise RuntimeError(
                    "Too many turns. Stopping iteration to avoid using too many tokens."
                )

            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                break

            skip_response_format = not turn.final_call
            all_messages = chat_formatter.message_dicts()
            completion_kwargs = await self.build_completion_kwargs(
                provider,
                all_messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )
            response = await litellm.acompletion(**completion_kwargs)
            if (
                not isinstance(response, ModelResponse)
                or not response.choices
                or len(response.choices) == 0
                or not isinstance(response.choices[0], Choices)
            ):
                raise RuntimeError(
                    f"Expected ModelResponse with Choices, got {type(response)}."
                )
            prior_message = response.choices[0].message
            prior_output = prior_message.content

            # Fallback: Use args of first tool call to task_response if it exists
            if (
                not prior_output
                and hasattr(prior_message, "tool_calls")
                and prior_message.tool_calls
            ):
                tool_call = next(
                    (
                        tool_call
                        for tool_call in prior_message.tool_calls
                        if tool_call.function.name == "task_response"
                    ),
                    None,
                )
                if tool_call:
                    prior_output = tool_call.function.arguments

            if not prior_output:
                raise RuntimeError("No output returned from model")

        if response is None or prior_message is None:
            raise RuntimeError("No response returned from model")

        intermediate_outputs = chat_formatter.intermediate_outputs()

        logprobs = (
            response.choices[0].logprobs
            if hasattr(response.choices[0], "logprobs")
            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
            else None
        )

        # Check logprobs worked, if requested
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
        if (
            prior_message is not None
            and hasattr(prior_message, "reasoning_content")
            and prior_message.reasoning_content
            and len(prior_message.reasoning_content.strip()) > 0
        ):
            intermediate_outputs["reasoning"] = prior_message.reasoning_content.strip()

        # the string content of the response
        response_content = prior_output

        if not isinstance(response_content, str):
            raise RuntimeError(f"response is not a string: {response_content}")

        return RunOutput(
            output=response_content,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
        ), self.usage_from_response(response)

    def adapter_name(self) -> str:
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        structured_output_mode = self.run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = self.run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)

    def json_schema_response_format(self) -> dict[str, Any]:
        output_schema = self.task().output_schema()
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task().output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # TODO P1: Don't love having this logic here. But it's a usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden

        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["Fireworks", "Together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["DeepInfra"]

        # Only set if this request is to get logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["DeepInfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and url endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an api url

        if self._litellm_model_id:
            return self._litellm_model_id

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        litellm_provider_name: str | None = None
        is_custom = False
        match provider.name:
            case ModelProviderName.openrouter:
                litellm_provider_name = "openrouter"
            case ModelProviderName.openai:
                litellm_provider_name = "openai"
            case ModelProviderName.groq:
                litellm_provider_name = "groq"
            case ModelProviderName.anthropic:
                litellm_provider_name = "anthropic"
            case ModelProviderName.ollama:
                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
                is_custom = True
            case ModelProviderName.gemini_api:
                litellm_provider_name = "gemini"
            case ModelProviderName.fireworks_ai:
                litellm_provider_name = "fireworks_ai"
            case ModelProviderName.amazon_bedrock:
                litellm_provider_name = "bedrock"
            case ModelProviderName.azure_openai:
                litellm_provider_name = "azure"
            case ModelProviderName.huggingface:
                litellm_provider_name = "huggingface"
            case ModelProviderName.vertex:
                litellm_provider_name = "vertex_ai"
            case ModelProviderName.together_ai:
                litellm_provider_name = "together_ai"
            case ModelProviderName.openai_compatible:
                is_custom = True
            case ModelProviderName.kiln_custom_registry:
                is_custom = True
            case ModelProviderName.kiln_fine_tune:
                is_custom = True
            case _:
                raise_exhaustive_enum_error(provider.name)

        if is_custom:
            if self._api_base is None:
                raise ValueError(
                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
                )
            # Use openai as it's only used for format, not url
            litellm_provider_name = "openai"

        # Shouldn't be possible but keep type checker happy
        if litellm_provider_name is None:
            raise ValueError(
                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
            )

        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[dict[str, Any]],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": self.run_config.temperature,
            "top_p": self.run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
            response_format_options = await self.response_format_options()
            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> Usage | None:
        litellm_usage = response.get("usage", None)
        cost = response._hidden_params.get("response_cost", None)
        if not litellm_usage and not cost:
            return None

        usage = Usage()

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage
logger = <Logger kiln_ai.adapters.model_adapters.litellm_adapter (WARNING)>
class LiteLlmAdapter(BaseAdapter):
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Attributes:
- prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
- kiln_task (Task): The task configuration and metadata
- output_schema (dict | None): JSON schema for validating structured outputs
- input_schema (dict | None): JSON schema for validating structured inputs
LiteLlmAdapter(
    config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig,
    kiln_task: kiln_ai.datamodel.Task,
    base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None
)
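A minimal construction sketch. The names my_config and my_task are assumptions (a LiteLlmConfig and a kiln_ai.datamodel.Task built elsewhere); the comments only restate what __init__ does in the module source above.

from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter

# Assumes my_config (LiteLlmConfig) and my_task (kiln_ai.datamodel.Task) already
# exist. base_adapter_config is optional and defaults to None.
adapter = LiteLlmAdapter(config=my_config, kiln_task=my_task)

# __init__ stores config.base_url and config.default_headers for later litellm
# calls, and builds a RunConfig from config.run_config_properties plus the task.

The inherited invoke and invoke_returning_run_output methods (listed under Inherited Members below) are the public entry points; _run above is the provider-specific piece they drive.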
async def response_format_options(self) -> dict[str, typing.Any]:
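For quick reference, the shapes returned for the most common modes, as implemented in the module source above. The output_schema value is a placeholder standing in for self.task().output_schema().

# Return value shapes of response_format_options(), by structured_output_mode.
output_schema = {"type": "object", "properties": {"answer": {"type": "string"}}}  # placeholder

# json_mode and json_instruction_and_object:
json_mode_options = {"response_format": {"type": "json_object"}}

# json_schema (and the default mode on Ollama):
json_schema_options = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": output_schema},
    }
}

# json_instructions / json_custom_instructions return {}: the JSON format is
# requested via the prompt rather than the API's response_format parameter.
no_options = {}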
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
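An illustrative return value for strict=True, mirroring the structure built in the module source above; the schema itself is a placeholder, not a real task schema.

# tool_call_params(strict=True) for a placeholder output schema.
tool_params = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "additionalProperties": False,  # always added by tool_call_params
                },
                "strict": True,  # omitted when strict=False (function_calling_weak)
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}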
def build_extra_body(self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
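An illustrative result for a hypothetical OpenRouter-hosted reasoning model with require_openrouter_reasoning and r1_openrouter_options enabled; real flag combinations depend on the KilnModelProvider entry.

# What build_extra_body() would produce for such a provider.
extra_body = {
    "reasoning": {"exclude": False},      # require_openrouter_reasoning
    "provider": {                         # OpenRouter routing options
        "require_parameters": True,       # only use providers that accept the reasoning param
        "order": ["Fireworks", "Together"],
        "ignore": ["DeepInfra"],
    },
}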
def litellm_model_id(self) -> str:
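Illustrative results of the provider-prefixing rule implemented above. The model ids here are hypothetical examples, not taken from this page.

# "<litellm provider>/<provider.model_id>" examples (hypothetical model ids).
examples = {
    ("openrouter", "deepseek/deepseek-r1"): "openrouter/deepseek/deepseek-r1",
    ("groq", "llama-3.3-70b-versatile"): "groq/llama-3.3-70b-versatile",
    # Custom-endpoint providers (ollama, openai_compatible, kiln_fine_tune,
    # kiln_custom_registry) always use the "openai/" prefix, since litellm only
    # uses it for request format, and they require an explicit base URL.
    ("ollama", "llama3.1"): "openai/llama3.1",
}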
async def build_completion_kwargs(
    self,
    provider: kiln_ai.adapters.ml_model_list.KilnModelProvider,
    messages: list[dict[str, typing.Any]],
    top_logprobs: int | None,
    skip_response_format: bool = False
) -> dict[str, typing.Any]:
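A condensed sketch of how _run (in the module source above) uses this method for a final turn. The helper name single_call is an assumption, and adapter is a constructed LiteLlmAdapter from the earlier sketch.

import litellm

# Hypothetical helper: build kwargs the same way _run does for a final turn,
# then hand them to litellm.
async def single_call(adapter, messages: list[dict]):
    completion_kwargs = await adapter.build_completion_kwargs(
        provider=adapter.model_provider(),
        messages=messages,
        top_logprobs=None,           # pass an int here to also request logprobs
        skip_response_format=False,  # _run passes True on non-final turns
    )
    return await litellm.acompletion(**completion_kwargs)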
def usage_from_response(self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage | None:
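A usage sketch continuing the hypothetical single_call example above, run inside an async context.

# Extract token counts and cost (if litellm reported them) from the response.
response = await single_call(adapter, [{"role": "user", "content": "Hello"}])
usage = adapter.usage_from_response(response)
if usage is not None:
    print(usage.input_tokens, usage.output_tokens, usage.total_tokens, usage.cost)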
Inherited Members
- kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter
- run_config
- prompt_builder
- output_schema
- input_schema
- base_adapter_config
- task
- model_provider
- invoke
- invoke_returning_run_output
- has_structured_output
- build_prompt
- build_chat_formatter
- generate_run
- update_run_config_unknown_structured_output_mode