kiln_ai.adapters.model_adapters.litellm_adapter
import logging
from typing import Any, Dict

import litellm
from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
from litellm.types.utils import Usage as LiteLlmUsage

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
    COT_FINAL_ANSWER_PROMPT,
    AdapterConfig,
    BaseAdapter,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel import PromptGenerators, PromptId
from kiln_ai.datamodel.task import RunConfig
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

logger = logging.getLogger(__name__)


class LiteLlmAdapter(BaseAdapter):
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        prompt_id: PromptId | None = None,
        base_adapter_config: AdapterConfig | None = None,
    ):
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None

        run_config = RunConfig(
            task=kiln_task,
            model_name=config.model_name,
            model_provider_name=config.provider_name,
            prompt_id=prompt_id or PromptGenerators.SIMPLE,
        )

        super().__init__(
            run_config=run_config,
            config=base_adapter_config,
        )

    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        intermediate_outputs: dict[str, str] = {}
        prompt = self.build_prompt()
        user_msg = self.prompt_builder.build_user_message(input)
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_msg},
        ]

        run_strategy, cot_prompt = self.run_strategy()

        if run_strategy == "cot_as_message":
            # Used for reasoning-capable models that can output thinking and a structured format
            if not cot_prompt:
                raise ValueError("cot_prompt is required for cot_as_message strategy")
            messages.append({"role": "system", "content": cot_prompt})
        elif run_strategy == "cot_two_call":
            if not cot_prompt:
                raise ValueError("cot_prompt is required for cot_two_call strategy")
            messages.append({"role": "system", "content": cot_prompt})

            # First call for chain of thought
            # No response format, as this request is for "thinking" in plain text
            # No logprobs, as they are only needed for the final answer
            completion_kwargs = await self.build_completion_kwargs(
                provider, messages, None, skip_response_format=True
            )
            cot_response = await litellm.acompletion(**completion_kwargs)
            if (
                not isinstance(cot_response, ModelResponse)
                or not cot_response.choices
                or len(cot_response.choices) == 0
                or not isinstance(cot_response.choices[0], Choices)
            ):
                raise RuntimeError(
                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
                )
            cot_content = cot_response.choices[0].message.content
            if cot_content is not None:
                intermediate_outputs["chain_of_thought"] = cot_content

            messages.extend(
                [
                    {"role": "assistant", "content": cot_content or ""},
                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
                ]
            )

        # Make the API call using litellm
        completion_kwargs = await self.build_completion_kwargs(
            provider, messages, self.base_adapter_config.top_logprobs
        )
        response = await litellm.acompletion(**completion_kwargs)

        if not isinstance(response, ModelResponse):
            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")

        # Maybe remove this? There is no error attribute on the response object.
        # Keeping in a typesafe way as we added it for a reason, but should investigate what that was and if it still applies.
        if hasattr(response, "error") and response.__getattribute__("error"):
            raise RuntimeError(
                f"LLM API returned an error: {response.__getattribute__('error')}"
            )

        if (
            not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                "No message content returned in the response from LLM API"
            )

        message = response.choices[0].message
        logprobs = (
            response.choices[0].logprobs
            if hasattr(response.choices[0], "logprobs")
            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
            else None
        )

        # Check that logprobs were returned, if requested
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        # Save reasoning if it exists and was parsed by LiteLLM (or OpenRouter, or anyone upstream)
        if (
            hasattr(message, "reasoning_content")
            and message.reasoning_content
            and len(message.reasoning_content.strip()) > 0
        ):
            intermediate_outputs["reasoning"] = message.reasoning_content.strip()

        # The string content of the response
        response_content = message.content

        # Fallback: use the arguments of the first tool call to task_response, if it exists
        if (
            not response_content
            and hasattr(message, "tool_calls")
            and message.tool_calls
        ):
            tool_call = next(
                (
                    tool_call
                    for tool_call in message.tool_calls
                    if tool_call.function.name == "task_response"
                ),
                None,
            )
            if tool_call:
                response_content = tool_call.function.arguments

        if not isinstance(response_content, str):
            raise RuntimeError(f"response is not a string: {response_content}")

        return RunOutput(
            output=response_content,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
        ), self.usage_from_response(response)

    def adapter_name(self) -> str:
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        provider = self.model_provider()
        match provider.structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in the prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in the system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set JSON instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                if provider.name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider.name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case _:
                raise_exhaustive_enum_error(provider.structured_output_mode)

    def json_schema_response_format(self) -> dict[str, Any]:
        output_schema = self.task().output_schema()
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task().output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # TODO P1: Don't love having this logic here. But it's a usability improvement,
        # so better to keep it than exclude it. Should figure out how to isolate
        # this sort of logic so it's config-driven and can be overridden.

        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["Fireworks", "Together"]
            # Ignore R1 providers with unreasonable quants
            provider_options["ignore"] = ["DeepInfra"]

        # Only set when this request needs logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["DeepInfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case: R1 14B/8B/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and URL endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an API URL.

        if self._litellm_model_id:
            return self._litellm_model_id

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        litellm_provider_name: str | None = None
        is_custom = False
        match provider.name:
            case ModelProviderName.openrouter:
                litellm_provider_name = "openrouter"
            case ModelProviderName.openai:
                litellm_provider_name = "openai"
            case ModelProviderName.groq:
                litellm_provider_name = "groq"
            case ModelProviderName.anthropic:
                litellm_provider_name = "anthropic"
            case ModelProviderName.ollama:
                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI-compatible API.
                # This is because we're setting detailed features like response_format=json_schema and want lower-level control.
                is_custom = True
            case ModelProviderName.gemini_api:
                litellm_provider_name = "gemini"
            case ModelProviderName.fireworks_ai:
                litellm_provider_name = "fireworks_ai"
            case ModelProviderName.amazon_bedrock:
                litellm_provider_name = "bedrock"
            case ModelProviderName.azure_openai:
                litellm_provider_name = "azure"
            case ModelProviderName.huggingface:
                litellm_provider_name = "huggingface"
            case ModelProviderName.vertex:
                litellm_provider_name = "vertex_ai"
            case ModelProviderName.together_ai:
                litellm_provider_name = "together_ai"
            case ModelProviderName.openai_compatible:
                is_custom = True
            case ModelProviderName.kiln_custom_registry:
                is_custom = True
            case ModelProviderName.kiln_fine_tune:
                is_custom = True
            case _:
                raise_exhaustive_enum_error(provider.name)

        if is_custom:
            if self._api_base is None:
                raise ValueError(
                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
                )
            # Use openai as it's only used for format, not URL
            litellm_provider_name = "openai"

        # Shouldn't be possible, but keeps the type checker happy
        if litellm_provider_name is None:
            raise ValueError(
                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
            )

        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[dict[str, Any]],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            **extra_body,
            **self._additional_body_options,
        }

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc.
            response_format_options = await self.response_format_options()
            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> Usage | None:
        litellm_usage = response.get("usage", None)
        cost = response._hidden_params.get("response_cost", None)
        if not litellm_usage and not cost:
            return None

        usage = Usage()

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage
logger = <Logger kiln_ai.adapters.model_adapters.litellm_adapter (WARNING)>
class LiteLlmAdapter(kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter):
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs
LiteLlmAdapter(
    config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig,
    kiln_task: kiln_ai.datamodel.Task,
    prompt_id: kiln_ai.datamodel.PromptId | None = None,
    base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None
)
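A minimal construction sketch. The config field names mirror the attributes that __init__ reads; whether LiteLlmConfig's constructor accepts them as keyword arguments, and the specific model/provider names and api_key option shown, are assumptions for illustration only.

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig


def make_adapter(task: datamodel.Task) -> LiteLlmAdapter:
    config = LiteLlmConfig(
        model_name="gpt_4o_mini",                       # hypothetical Kiln model name
        provider_name="openai",                         # hypothetical provider name
        base_url=None,                                   # only needed for custom / OpenAI-compatible providers
        default_headers=None,
        additional_body_options={"api_key": "sk-..."},  # passed straight through to litellm (assumption)
    )
    # prompt_id defaults to PromptGenerators.SIMPLE when omitted.
    return LiteLlmAdapter(config=config, kiln_task=task)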
async def response_format_options(self) -> dict[str, typing.Any]:
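For reference, a sketch of the return shapes produced by the branches of this method; the output schema is a stand-in, not a real task schema.

# Illustrative return values, mirroring the structured output modes above.
output_schema = {"type": "object", "properties": {"answer": {"type": "string"}}}  # stand-in schema

json_mode_options = {"response_format": {"type": "json_object"}}  # json_mode, json_instruction_and_object

json_schema_options = {  # json_schema (built by json_schema_response_format)
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": output_schema},
    }
}

prompt_only_options = {}  # json_instructions / json_custom_instructions: formatting handled in the prompt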
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
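Illustrative output for tool_call_params(strict=True), using a stand-in schema; the structure matches the dict built above.

output_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "additionalProperties": False,  # added by the method; required by OpenAI for strict mode
}

tool_params = {
    "tools": [
        {
            "type": "function",
            "function": {"name": "task_response", "parameters": output_schema, "strict": True},
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}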
def build_extra_body(self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
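A sketch of the extra body this method produces for a hypothetical OpenRouter R1-style provider that has require_openrouter_reasoning and r1_openrouter_options enabled.

extra_body = {
    "reasoning": {"exclude": False},       # require_openrouter_reasoning
    "provider": {
        "require_parameters": True,        # r1_openrouter_options
        "order": ["Fireworks", "Together"],
        "ignore": ["DeepInfra"],
    },
}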
def litellm_model_id(self) -> str:
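Examples of the "<litellm provider>/<model id>" strings this method returns; the specific model IDs below are illustrative, not taken from Kiln's model list.

openrouter_id = "openrouter" + "/" + "deepseek/deepseek-r1"   # hosted provider
anthropic_id = "anthropic" + "/" + "claude-3-5-sonnet"
# Custom providers (ollama, openai_compatible, kiln_fine_tune, kiln_custom_registry)
# always use the "openai" prefix (format only, not URL) and require an explicit api_base,
# e.g. "http://localhost:11434/v1" for Ollama's OpenAI-compatible endpoint.
ollama_id = "openai" + "/" + "llama3.1:8b"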
async def build_completion_kwargs(self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[dict[str, typing.Any]], top_logprobs: int | None, skip_response_format: bool = False) -> dict[str, typing.Any]:
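A sketch of the merged kwargs for a structured task with logprobs requested; the model ID, messages, and top_logprobs value are illustrative.

completion_kwargs = {
    "model": "openrouter/deepseek/deepseek-r1",   # litellm_model_id()
    "messages": [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "..."},
    ],
    "api_base": None,                             # set only for custom / OpenAI-compatible providers
    "headers": None,
    "reasoning": {"exclude": False},              # merged in from build_extra_body()
    "response_format": {"type": "json_object"},   # merged in from response_format_options()
    "logprobs": True,                             # added because top_logprobs was not None
    "top_logprobs": 5,
}
# response = await litellm.acompletion(**completion_kwargs)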
def usage_from_response(self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage | None:
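A usage sketch, assuming `adapter` is a constructed LiteLlmAdapter and `response` is a litellm ModelResponse from a prior acompletion call.

usage = adapter.usage_from_response(response)
if usage is not None:
    # Token counts come from response.usage; cost comes from litellm's hidden response_cost estimate.
    print(usage.input_tokens, usage.output_tokens, usage.total_tokens, usage.cost)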