kiln_ai.adapters.model_adapters.litellm_adapter
import copy
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, TypeAlias, Union

import litellm
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    ChoiceLogprobs,
    Choices,
    ModelResponse,
)
from litellm.types.utils import Message as LiteLLMMessage
from litellm.types.utils import Usage as LiteLlmUsage
from openai.types.chat.chat_completion_message_tool_call_param import (
    ChatCompletionMessageToolCallParam,
)

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.datamodel_enums import InputType
from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
from kiln_ai.tools.base_tool import (
    KilnToolInterface,
    ToolCallContext,
    ToolCallDefinition,
)
from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
from kiln_ai.utils.litellm import get_litellm_provider_info
from kiln_ai.utils.open_ai_types import (
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionMessageParam,
    ChatCompletionToolMessageParamWrapper,
)

MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30

logger = logging.getLogger(__name__)

ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
    ChatCompletionMessageParam, LiteLLMMessage
]


@dataclass
class ModelTurnResult:
    assistant_message: str
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    model_response: ModelResponse | None
    model_choice: Choices | None
    usage: Usage


class LiteLlmAdapter(BaseAdapter):
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None
        self._cached_available_tools: list[KilnToolInterface] | None = None

        super().__init__(
            task=kiln_task,
            run_config=config.run_config_properties,
            config=base_adapter_config,
        )

    async def _run_model_turn(
        self,
        provider: KilnModelProvider,
        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool,
    ) -> ModelTurnResult:
        """
        Call the model for a single top-level turn: from user message to agent message.

        It may handle iterations of tool calls between the user/agent messages if needed.
        """

        usage = Usage()
        messages = list(prior_messages)
        tool_calls_count = 0

        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
            # Build completion kwargs for tool calls
            completion_kwargs = await self.build_completion_kwargs(
                provider,
                # Pass a copy, as acompletion mutates objects and breaks types.
                copy.deepcopy(messages),
                top_logprobs,
                skip_response_format,
            )

            # Make the completion call
            model_response, response_choice = await self.acompletion_checking_response(
                **completion_kwargs
            )

            # count the usage
            usage += self.usage_from_response(model_response)

            # Extract content and tool calls
            if not hasattr(response_choice, "message"):
                raise ValueError("Response choice has no message")
            content = response_choice.message.content
            tool_calls = response_choice.message.tool_calls
            if not content and not tool_calls:
                raise ValueError(
                    "Model returned an assistant message, but no content or tool calls. This is not supported."
                )

            # Add message to messages, so it can be used in the next turn
            messages.append(response_choice.message)

            # Process tool calls if any
            if tool_calls and len(tool_calls) > 0:
                (
                    assistant_message_from_toolcall,
                    tool_call_messages,
                ) = await self.process_tool_calls(tool_calls)

                # Add tool call results to messages
                messages.extend(tool_call_messages)

                # If task_response tool was called, we're done
                if assistant_message_from_toolcall is not None:
                    return ModelTurnResult(
                        assistant_message=assistant_message_from_toolcall,
                        all_messages=messages,
                        model_response=model_response,
                        model_choice=response_choice,
                        usage=usage,
                    )

                # If there were tool calls, increment counter and continue
                if tool_call_messages:
                    tool_calls_count += 1
                    continue

            # If no tool calls, return the content as final output
            if content:
                return ModelTurnResult(
                    assistant_message=content,
                    all_messages=messages,
                    model_response=model_response,
                    model_choice=response_choice,
                    usage=usage,
                )

            # If we get here with no content and no tool calls, raise
            raise RuntimeError(
                "Model returned neither content nor tool calls. It must return at least one of these."
            )

        raise RuntimeError(
            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
        )

    async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]:
        usage = Usage()

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input)
        messages: list[ChatCompletionMessageIncludingLiteLLM] = []

        prior_output: str | None = None
        final_choice: Choices | None = None
        turns = 0

        while True:
            turns += 1
            if turns > MAX_CALLS_PER_TURN:
                raise RuntimeError(
                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                )

            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                # No next turn, we're done
                break

            # Add messages from the turn to chat history
            for message in turn.messages:
                if message.content is None:
                    raise ValueError("Empty message content isn't allowed")
                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
                messages.append({"role": message.role, "content": message.content})  # type: ignore

            skip_response_format = not turn.final_call
            turn_result = await self._run_model_turn(
                provider,
                messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )

            usage += turn_result.usage

            prior_output = turn_result.assistant_message
            messages = turn_result.all_messages
            final_choice = turn_result.model_choice

        if not prior_output:
            raise RuntimeError("No assistant message/output returned from model")

        logprobs = self._extract_and_validate_logprobs(final_choice)

        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
        intermediate_outputs = chat_formatter.intermediate_outputs()
        self._extract_reasoning_to_intermediate_outputs(
            final_choice, intermediate_outputs
        )

        if not isinstance(prior_output, str):
            raise RuntimeError(f"assistant message is not a string: {prior_output}")

        trace = self.all_messages_to_trace(messages)
        output = RunOutput(
            output=prior_output,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
            trace=trace,
        )

        return output, usage

    def _extract_and_validate_logprobs(
        self, final_choice: Choices | None
    ) -> ChoiceLogprobs | None:
        """
        Extract logprobs from the final choice and validate they exist if required.
        """
        logprobs = None
        if (
            final_choice is not None
            and hasattr(final_choice, "logprobs")
            and isinstance(final_choice.logprobs, ChoiceLogprobs)
        ):
            logprobs = final_choice.logprobs

        # Check logprobs worked, if required
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        return logprobs

    def _extract_reasoning_to_intermediate_outputs(
        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
    ) -> None:
        """Extract reasoning content from model choice and add to intermediate outputs if present."""
        if (
            final_choice is not None
            and hasattr(final_choice, "message")
            and hasattr(final_choice.message, "reasoning_content")
        ):
            reasoning_content = final_choice.message.reasoning_content
            if reasoning_content is not None:
                stripped_reasoning_content = reasoning_content.strip()
                if len(stripped_reasoning_content) > 0:
                    intermediate_outputs["reasoning"] = stripped_reasoning_content

    async def acompletion_checking_response(
        self, **kwargs
    ) -> Tuple[ModelResponse, Choices]:
        response = await litellm.acompletion(**kwargs)
        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]

    def adapter_name(self) -> str:
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        structured_output_mode = self.run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = self.run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                elif provider_name == ModelProviderName.docker_model_runner:
                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)

    def json_schema_response_format(self) -> dict[str, Any]:
        output_schema = self.task.output_schema()
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additional_properties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task.output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # Don't love having this logic here. But it's worth the usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden

        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # The only goal of this request is to get logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # Siliconflow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and url endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an API URL.
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": self.run_config.temperature,
            "top_p": self.run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and response format using tools
            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> Usage:
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage

    async def cached_available_tools(self) -> list[KilnToolInterface]:
        if self._cached_available_tools is None:
            self._cached_available_tools = await self.available_tools()
        return self._cached_available_tools

    async def litellm_tools(self) -> list[ToolCallDefinition]:
        available_tools = await self.cached_available_tools()

        # LiteLLM takes the standard OpenAI-compatible tool call format
        return [await tool.toolcall_definition() for tool in available_tools]

    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []

        for tool_call in tool_calls:
            # The Kiln "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )
            result = await tool.run(context, **parsed_args)

            tool_call_response_messages.append(
                ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tool_call.id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                )
            )

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for an impossible combination: a task_response tool call and other tool calls were both provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages

    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message

    def all_messages_to_trace(
        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
        """
        trace: list[ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, LiteLLMMessage):
                trace.append(self.litellm_message_to_trace_message(message))
            else:
                trace.append(message)
        return trace
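For orientation, build_completion_kwargs merges the run config, the provider-specific extra_body options, any tool definitions, and the structured-output response format into a single kwargs dict for litellm.acompletion. A rough sketch of what that dict can look like for a provider using the json_schema response format; the values are illustrative, not taken from a real run:

completion_kwargs = {
    "model": "openai/gpt-4o",        # illustrative; the real value comes from litellm_model_id()
    "messages": [{"role": "user", "content": "Summarize the ticket."}],
    "api_base": None,                # only set for custom / OpenAI-compatible endpoints
    "headers": None,
    "temperature": 0.7,
    "top_p": 1.0,
    "drop_params": True,             # let litellm drop OpenAI params a model doesn't support
    "response_format": {             # from response_format_options() in json_schema mode
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": {"type": "object"}},
    },
}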
MAX_CALLS_PER_TURN = 10

MAX_TOOL_CALLS_PER_TURN = 30

logger = <Logger kiln_ai.adapters.model_adapters.litellm_adapter (WARNING)>

ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
    openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam,
    openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam,
    openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam,
    kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper,
    kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper,
    openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam,
    litellm.types.utils.Message,
]
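In practice the alias means a chat history can mix plain OpenAI-style message dicts with the litellm Message objects returned by acompletion; all_messages_to_trace() later normalizes the mix back to OpenAI-compatible dicts. A minimal sketch (constructing a litellm Message directly here is only for illustration):

from litellm.types.utils import Message as LiteLLMMessage

history: list[ChatCompletionMessageIncludingLiteLLM] = [
    {"role": "user", "content": "Summarize the ticket."},      # OpenAI-style dict
    LiteLLMMessage(role="assistant", content="Summary: ..."),  # object returned by litellm
]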
@dataclass
class ModelTurnResult:
@dataclass
class ModelTurnResult:
    assistant_message: str
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    model_response: ModelResponse | None
    model_choice: Choices | None
    usage: Usage
ModelTurnResult(
    assistant_message: str,
    all_messages: list[ChatCompletionMessageIncludingLiteLLM],
    model_response: litellm.types.utils.ModelResponse | None,
    model_choice: litellm.types.utils.Choices | None,
    usage: kiln_ai.datamodel.Usage,
)

all_messages: list[ChatCompletionMessageIncludingLiteLLM]
usage: kiln_ai.datamodel.Usage
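ModelTurnResult is a plain dataclass; _run_model_turn fills it in once a turn produces either content or a task_response tool call. A sketch mirroring the fields above:

turn_result = ModelTurnResult(
    assistant_message=content,        # final text, or the task_response arguments
    all_messages=messages,            # full history, including any tool result messages
    model_response=model_response,    # raw litellm ModelResponse (or None)
    model_choice=response_choice,     # the Choices entry the content came from
    usage=usage,                      # Usage accumulated across the turn's model calls
)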
class LiteLlmAdapter(BaseAdapter):
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
LiteLlmAdapter(
    config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig,
    kiln_task: kiln_ai.datamodel.Task,
    base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None,
)
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None
        self._cached_available_tools: list[KilnToolInterface] | None = None

        super().__init__(
            task=kiln_task,
            run_config=config.run_config_properties,
            config=base_adapter_config,
        )
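A hedged construction sketch. It assumes LiteLlmConfig accepts the attributes read above (run_config_properties, base_url, default_headers, additional_body_options) as constructor arguments; the exact field names and types live in litellm_config.py and the Kiln datamodel, so treat this as illustrative rather than a verified API.

config = LiteLlmConfig(
    run_config_properties=run_config_properties,  # a kiln_ai run-config properties object
    base_url=None,                # required only for custom OpenAI-compatible endpoints
    default_headers=None,
    additional_body_options={},   # extra kwargs forwarded verbatim to litellm.acompletion
)
adapter = LiteLlmAdapter(config=config, kiln_task=task)

The adapter is then driven through the BaseAdapter entry points rather than by calling _run directly.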
async def acompletion_checking_response(self, **kwargs) -> Tuple[litellm.types.utils.ModelResponse, litellm.types.utils.Choices]:
    async def acompletion_checking_response(
        self, **kwargs
    ) -> Tuple[ModelResponse, Choices]:
        response = await litellm.acompletion(**kwargs)
        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]
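This is a thin guard around litellm.acompletion: it forwards whatever kwargs it receives and narrows the response to a ModelResponse with at least one Choices entry. A sketch of how the adapter itself drives it inside _run_model_turn:

completion_kwargs = await self.build_completion_kwargs(
    provider, copy.deepcopy(messages), top_logprobs, skip_response_format
)
model_response, choice = await self.acompletion_checking_response(**completion_kwargs)
usage += self.usage_from_response(model_response)
content = choice.message.content  # may be None when the model only requests tool calls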
async def response_format_options(self) -> dict[str, typing.Any]:
    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        structured_output_mode = self.run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = self.run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                elif provider_name == ModelProviderName.docker_model_runner:
                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)
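The return value is always a dict of extra kwargs for litellm (possibly empty). Sketches of the main shapes, taken from the branches above; the schema contents are illustrative:

# StructuredOutputMode.json_schema (via json_schema_response_format())
{
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": {"type": "object", "properties": {}}},
    }
}

# StructuredOutputMode.json_mode and json_instruction_and_object
{"response_format": {"type": "json_object"}}

# json_instructions and json_custom_instructions: formatting is handled in the prompt
{}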
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task.output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # Strict should be on, but we allow function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }
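To make the payload concrete, here is what tool_call_params(strict=True) produces for a toy output schema. The schema below is hypothetical; the surrounding structure mirrors the method above.

# Hypothetical task output schema, for illustration only.
toy_output_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
    "additionalProperties": False,  # added by tool_call_params
}

# Equivalent of the dict returned by tool_call_params(strict=True):
expected_params = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": toy_output_schema,
                "strict": True,  # omitted when strict=False (function_calling_weak)
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}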
def build_extra_body(self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # Don't love having this logic here, but it's worth the usability improvement,
        # so better to keep it than exclude it. Should figure out how to isolate this
        # sort of logic so it's config driven and can be overridden.

        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage (cost) in the response
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # Ignore R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # The only purpose of this request option is to fetch logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case: R1 14B/8B/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # SiliconFlow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body
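As an illustration, an OpenRouter provider entry with require_openrouter_reasoning set and no other special flags would yield roughly the dict below. The exact flags depend on the provider's entry in ml_model_list, so treat this as a sketch rather than a guaranteed output.

# Approximate build_extra_body result for an OpenRouter provider with
# require_openrouter_reasoning=True and no other special flags set.
example_extra_body = {
    "reasoning": {"exclude": False},
    "usage": {"include": True},
    "provider": {
        "order": [
            "fireworks", "parasail", "together", "deepinfra", "novita",
            "groq", "amazon-bedrock", "azure", "nebius",
        ],
    },
}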
def litellm_model_id(self) -> str:

    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and url endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an api url.
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id
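LiteLLM model IDs generally follow the '<litellm_provider>/<model>' convention, and custom or OpenAI-compatible endpoints additionally need the explicit api_base checked above. The IDs and URL below are hypothetical examples, not values this method is guaranteed to return:

# Hypothetical litellm-style model IDs; the real value comes from
# get_litellm_provider_info(...) and varies per provider and model.
hosted_model_id = "openrouter/qwen/qwen-2.5-72b-instruct"  # hosted provider, no api_base required

# OpenAI-compatible / custom providers pair an "openai/..." style ID with an explicit base URL:
custom_model_id = "openai/my-local-model"      # hypothetical
custom_api_base = "http://localhost:11434/v1"  # hypothetical local OpenAI-compatible endpoint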
async def build_completion_kwargs(
    self,
    provider: kiln_ai.adapters.ml_model_list.KilnModelProvider,
    messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]],
    top_logprobs: int | None,
    skip_response_format: bool = False,
) -> dict[str, typing.Any]:

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": self.run_config.temperature,
            "top_p": self.run_config.top_p,
            # This drops params that are not supported by the model. Only OpenAI params like top_p and temperature -- not litellm params like model, etc.
            # Not all models and providers support all OpenAI params (for example, o3 doesn't support top_p).
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temperature, not both.
        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc.
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and a response format that also uses tools.
            # We could reconsider this: the model could choose between a final answer or a tool call on any turn. However, models that are good with tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs
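The dict returned above is passed directly to litellm.acompletion (via acompletion_checking_response). A minimal standalone sketch of that flow, with hand-built kwargs in the same shape; the model ID and message are hypothetical, and provider credentials are assumed to be configured in the environment:

import asyncio

import litellm


async def main() -> None:
    # Hypothetical kwargs, mirroring the shape build_completion_kwargs returns.
    completion_kwargs = {
        "model": "openrouter/qwen/qwen-2.5-72b-instruct",  # hypothetical model ID
        "messages": [
            {"role": "user", "content": "Return a JSON object with a 'greeting' key."}
        ],
        "api_base": None,     # only needed for custom / OpenAI-compatible endpoints
        "headers": None,
        "temperature": 0.7,
        "top_p": 1.0,
        "drop_params": True,  # silently drop OpenAI params the model doesn't support
        "response_format": {"type": "json_object"},  # e.g. from response_format_options()
    }
    response = await litellm.acompletion(**completion_kwargs)
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())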
def usage_from_response(self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage:

    def usage_from_response(self, response: ModelResponse) -> Usage:
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage
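The Usage objects built here are accumulated across model calls by the turn loop (usage += ...). A small sketch of that accumulation with made-up values, assuming Usage's += sums token counts and cost:

from kiln_ai.adapters.model_adapters.base_adapter import Usage

# Made-up per-call usage values, mirroring the fields populated above.
call_usage = Usage()
call_usage.input_tokens = 120
call_usage.output_tokens = 45
call_usage.total_tokens = 165
call_usage.cost = 0.00021

total = Usage()
total += call_usage  # same accumulation the turn loop performs per model call
total += call_usage
print(total.total_tokens)  # 330, assuming += sums token counts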
async def process_tool_calls(self, tool_calls: list[litellm.types.utils.ChatCompletionMessageToolCall] | None) -> tuple[str | None, list[kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper]]:

    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []

        for tool_call in tool_calls:
            # Kiln's "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )
            result = await tool.run(context, **parsed_args)

            tool_call_response_messages.append(
                ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tool_call.id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                )
            )

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for an impossible combination: a task_response tool call and other tool calls were provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages
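For context, the only members of a tool that this method touches are three awaitables: name(), toolcall_definition() (whose ["function"]["parameters"] is used for validation), and run(context, **args) returning an object with an .output string. Below is a hedged, duck-typed sketch of that surface; it is illustrative only and is not presented as a complete KilnToolInterface implementation.

from types import SimpleNamespace


class EchoToolSketch:
    """Hypothetical tool exposing just the surface process_tool_calls relies on."""

    async def name(self) -> str:
        return "echo"

    async def toolcall_definition(self) -> dict:
        # OpenAI-style tool definition; ["function"]["parameters"] is what the
        # adapter validates the model's arguments against before calling run().
        return {
            "type": "function",
            "function": {
                "name": "echo",
                "description": "Echo the provided text back.",
                "parameters": {
                    "type": "object",
                    "properties": {"text": {"type": "string"}},
                    "required": ["text"],
                    "additionalProperties": False,
                },
            },
        }

    async def run(self, context, text: str):
        # Real tools return a result object exposing .output (see result.output above);
        # a plain namespace keeps this sketch self-contained.
        return SimpleNamespace(output=f"echoed: {text}")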
def litellm_message_to_trace_message(self, raw_message: litellm.types.utils.Message) -> kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper:

    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message
Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
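For reference, a converted assistant message that contains a single tool call looks roughly like the dict below. The ID and arguments are made up, and reasoning_content only appears when the underlying LiteLLM message carries it:

# Illustrative output shape of litellm_message_to_trace_message.
example_trace_message = {
    "role": "assistant",
    "content": None,            # may be None when the model only returned tool calls
    "reasoning_content": None,  # populated for reasoning models, when available
    "tool_calls": [
        {
            "id": "call_123",   # made-up ID
            "type": "function",
            "function": {"name": "echo", "arguments": "{\"text\": \"hello\"}"},
        }
    ],
}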
def all_messages_to_trace(
    self,
    messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]],
) -> list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam]]:

    def all_messages_to_trace(
        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
        """
        trace: list[ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, LiteLLMMessage):
                trace.append(self.litellm_message_to_trace_message(message))
            else:
                trace.append(message)
        return trace
Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
Inherited Members
- kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter
- task
- run_config
- prompt_builder
- output_schema
- input_schema
- base_adapter_config
- model_provider
- invoke
- invoke_returning_run_output
- has_structured_output
- build_prompt
- build_chat_formatter
- generate_run
- update_run_config_unknown_structured_output_mode
- available_tools