kiln_ai.adapters.model_adapters.litellm_adapter
from typing import Any, Dict

import litellm
from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
    COT_FINAL_ANSWER_PROMPT,
    AdapterConfig,
    BaseAdapter,
    RunOutput,
)
from kiln_ai.adapters.model_adapters.litellm_config import (
    LiteLlmConfig,
)
from kiln_ai.datamodel import PromptGenerators, PromptId
from kiln_ai.datamodel.task import RunConfig
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


class LiteLlmAdapter(BaseAdapter):
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        prompt_id: PromptId | None = None,
        base_adapter_config: AdapterConfig | None = None,
    ):
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None

        run_config = RunConfig(
            task=kiln_task,
            model_name=config.model_name,
            model_provider_name=config.provider_name,
            prompt_id=prompt_id or PromptGenerators.SIMPLE,
        )

        super().__init__(
            run_config=run_config,
            config=base_adapter_config,
        )

    async def _run(self, input: Dict | str) -> RunOutput:
        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        intermediate_outputs: dict[str, str] = {}
        prompt = self.build_prompt()
        user_msg = self.prompt_builder.build_user_message(input)
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_msg},
        ]

        run_strategy, cot_prompt = self.run_strategy()

        if run_strategy == "cot_as_message":
            if not cot_prompt:
                raise ValueError("cot_prompt is required for cot_as_message strategy")
            messages.append({"role": "system", "content": cot_prompt})
        elif run_strategy == "cot_two_call":
            if not cot_prompt:
                raise ValueError("cot_prompt is required for cot_two_call strategy")
            messages.append({"role": "system", "content": cot_prompt})

            # First call for chain of thought - no logprobs, as they are only needed for the final answer
            completion_kwargs = await self.build_completion_kwargs(
                provider, messages, None
            )
            cot_response = await litellm.acompletion(**completion_kwargs)
            if (
                not isinstance(cot_response, ModelResponse)
                or not cot_response.choices
                or len(cot_response.choices) == 0
                or not isinstance(cot_response.choices[0], Choices)
            ):
                raise RuntimeError(
                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
                )
            cot_content = cot_response.choices[0].message.content
            if cot_content is not None:
                intermediate_outputs["chain_of_thought"] = cot_content

            messages.extend(
                [
                    {"role": "assistant", "content": cot_content or ""},
                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
                ]
            )

        # Make the API call using litellm
        completion_kwargs = await self.build_completion_kwargs(
            provider, messages, self.base_adapter_config.top_logprobs
        )
        response = await litellm.acompletion(**completion_kwargs)

        if not isinstance(response, ModelResponse):
            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")

        # Maybe remove this? There is no error attribute on the response object.
        # Keeping in a typesafe way as we added it for a reason, but should investigate what that was and if it still applies.
        if hasattr(response, "error") and response.__getattribute__("error"):
            raise RuntimeError(
                f"LLM API returned an error: {response.__getattribute__('error')}"
            )

        if (
            not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                "No message content returned in the response from LLM API"
            )

        message = response.choices[0].message
        logprobs = (
            response.choices[0].logprobs
            if hasattr(response.choices[0], "logprobs")
            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
            else None
        )

        # Check logprobs worked, if requested
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        # Save reasoning if it exists and was parsed by LiteLLM (or OpenRouter, or anyone upstream)
        if hasattr(message, "reasoning_content") and message.reasoning_content:
            intermediate_outputs["reasoning"] = message.reasoning_content

        # The string content of the response
        response_content = message.content

        # Fallback: use the args of the first task_response tool call, if it exists
        if (
            not response_content
            and hasattr(message, "tool_calls")
            and message.tool_calls
        ):
            tool_call = next(
                (
                    tool_call
                    for tool_call in message.tool_calls
                    if tool_call.function.name == "task_response"
                ),
                None,
            )
            if tool_call:
                response_content = tool_call.function.arguments

        if not isinstance(response_content, str):
            raise RuntimeError(f"response is not a string: {response_content}")

        return RunOutput(
            output=response_content,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
        )

    def adapter_name(self) -> str:
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        provider = self.model_provider()
        match provider.structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                if provider.name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider.name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case _:
                raise_exhaustive_enum_error(provider.structured_output_mode)

    def json_schema_response_format(self) -> dict[str, Any]:
        output_schema = self.task().output_schema()
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additional_properties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task().output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # TODO P1: Don't love having this logic here. But it's a usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden

        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["Fireworks", "Together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["DeepInfra"]

        # Only set these options when this request needs logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["DeepInfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case: R1 14B/8B/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and URL endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an API URL.

        if self._litellm_model_id:
            return self._litellm_model_id

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        litellm_provider_name: str | None = None
        is_custom = False
        match provider.name:
            case ModelProviderName.openrouter:
                litellm_provider_name = "openrouter"
            case ModelProviderName.openai:
                litellm_provider_name = "openai"
            case ModelProviderName.groq:
                litellm_provider_name = "groq"
            case ModelProviderName.anthropic:
                litellm_provider_name = "anthropic"
            case ModelProviderName.ollama:
                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
                is_custom = True
            case ModelProviderName.gemini_api:
                litellm_provider_name = "gemini"
            case ModelProviderName.fireworks_ai:
                litellm_provider_name = "fireworks_ai"
            case ModelProviderName.amazon_bedrock:
                litellm_provider_name = "bedrock"
            case ModelProviderName.azure_openai:
                litellm_provider_name = "azure"
            case ModelProviderName.huggingface:
                litellm_provider_name = "huggingface"
            case ModelProviderName.vertex:
                litellm_provider_name = "vertex_ai"
            case ModelProviderName.together_ai:
                litellm_provider_name = "together_ai"
            case ModelProviderName.openai_compatible:
                is_custom = True
            case ModelProviderName.kiln_custom_registry:
                is_custom = True
            case ModelProviderName.kiln_fine_tune:
                is_custom = True
            case _:
                raise_exhaustive_enum_error(provider.name)

        if is_custom:
            if self._api_base is None:
                raise ValueError(
                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
                )
            # Use openai as it's only used for the request format, not the URL
            litellm_provider_name = "openai"

        # Shouldn't be possible, but keep the type checker happy
        if litellm_provider_name is None:
            raise ValueError(
                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
            )

        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[dict[str, Any]],
        top_logprobs: int | None,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            **extra_body,
            **self._additional_body_options,
        }

        # Response format: json_schema, json_instructions, json_mode, function_calling, etc.
        response_format_options = await self.response_format_options()
        completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs
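For orientation, the "cot_two_call" branch of _run above issues two litellm.acompletion calls: the first with the chain-of-thought system prompt appended, the second with the model's reasoning echoed back as an assistant turn followed by a final-answer request. A minimal sketch of that message flow, using placeholder strings where the adapter would insert prompt-builder output (the COT_FINAL_ANSWER_PROMPT text shown is an assumed placeholder, not the real constant):

# Sketch of the message sequence for the "cot_two_call" strategy.
COT_FINAL_ANSWER_PROMPT = "Considering the above, return the final answer."  # assumed placeholder text

messages = [
    {"role": "system", "content": "<task system prompt>"},
    {"role": "user", "content": "<task input>"},
    {"role": "system", "content": "<chain-of-thought prompt>"},
]
# First completion call (no logprobs) returns the reasoning:
cot_content = "<model reasoning from call 1>"
messages.extend(
    [
        {"role": "assistant", "content": cot_content or ""},
        {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
    ]
)
# Second completion call returns the final answer (and logprobs, if requested).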
class LiteLlmAdapter(BaseAdapter):
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs
LiteLlmAdapter(
    config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig,
    kiln_task: kiln_ai.datamodel.Task,
    prompt_id: PromptId | None = None,
    base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None,
)
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        prompt_id: PromptId | None = None,
        base_adapter_config: AdapterConfig | None = None,
    ):
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None

        run_config = RunConfig(
            task=kiln_task,
            model_name=config.model_name,
            model_provider_name=config.provider_name,
            prompt_id=prompt_id or PromptGenerators.SIMPLE,
        )

        super().__init__(
            run_config=run_config,
            config=base_adapter_config,
        )
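A hedged construction sketch follows. The LiteLlmConfig and Task constructor arguments shown are assumptions inferred from the attributes the adapter reads (model_name, provider_name, base_url, default_headers, additional_body_options), not a confirmed public API, and the invoke call is the assumed BaseAdapter entry point.

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig

# Assumed field names, mirroring the attributes the adapter reads from its config.
config = LiteLlmConfig(
    model_name="gpt_4o_mini",                       # Kiln model name (illustrative)
    provider_name="openai",
    base_url=None,                                  # only required for custom/OpenAI-compatible providers
    default_headers=None,
    additional_body_options={"api_key": "sk-..."},  # passed straight through to litellm
)

# Assumed minimal Task fields; real tasks usually carry schemas and requirements too.
task = datamodel.Task(name="Summarize", instruction="Summarize the input text.")

adapter = LiteLlmAdapter(config=config, kiln_task=task)
# Run through the BaseAdapter entry point, e.g. `await adapter.invoke("Some long text...")` (assumed name).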
async def response_format_options(self) -> dict[str, typing.Any]:
    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        provider = self.model_provider()
        match provider.structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                if provider.name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider.name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case _:
                raise_exhaustive_enum_error(provider.structured_output_mode)
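To make the branches concrete, here is what the two most common structured-output options expand to for a toy output schema; the dict shapes come directly from json_schema_response_format and the json_object branches above.

toy_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# StructuredOutputMode.json_schema -> json_schema_response_format()
json_schema_options = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": toy_schema},
    }
}

# StructuredOutputMode.json_mode and json_instruction_and_object
json_mode_options = {"response_format": {"type": "json_object"}}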
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additional_properties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task().output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }
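A worked example of the return value for the same toy schema with strict=True, mirroring the structure built above:

toy_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
    "additionalProperties": False,  # added by tool_call_params
}

tool_call_options = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": toy_schema,
                "strict": True,  # omitted when strict=False (function_calling_weak)
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}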
def build_extra_body(self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # TODO P1: Don't love having this logic here. But it's a usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden

        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["Fireworks", "Together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["DeepInfra"]

        # Only set these options when this request needs logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["DeepInfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case: R1 14B/8B/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body
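As an illustration, a provider flagged with require_openrouter_reasoning and r1_openrouter_options would produce an extra_body like the following (values taken from the branches above):

extra_body = {
    "reasoning": {"exclude": False},
    "provider": {
        "require_parameters": True,
        "order": ["Fireworks", "Together"],
        "ignore": ["DeepInfra"],
    },
}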
def litellm_model_id(self) -> str:
    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and URL endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an API URL.

        if self._litellm_model_id:
            return self._litellm_model_id

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        litellm_provider_name: str | None = None
        is_custom = False
        match provider.name:
            case ModelProviderName.openrouter:
                litellm_provider_name = "openrouter"
            case ModelProviderName.openai:
                litellm_provider_name = "openai"
            case ModelProviderName.groq:
                litellm_provider_name = "groq"
            case ModelProviderName.anthropic:
                litellm_provider_name = "anthropic"
            case ModelProviderName.ollama:
                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
                is_custom = True
            case ModelProviderName.gemini_api:
                litellm_provider_name = "gemini"
            case ModelProviderName.fireworks_ai:
                litellm_provider_name = "fireworks_ai"
            case ModelProviderName.amazon_bedrock:
                litellm_provider_name = "bedrock"
            case ModelProviderName.azure_openai:
                litellm_provider_name = "azure"
            case ModelProviderName.huggingface:
                litellm_provider_name = "huggingface"
            case ModelProviderName.vertex:
                litellm_provider_name = "vertex_ai"
            case ModelProviderName.together_ai:
                litellm_provider_name = "together_ai"
            case ModelProviderName.openai_compatible:
                is_custom = True
            case ModelProviderName.kiln_custom_registry:
                is_custom = True
            case ModelProviderName.kiln_fine_tune:
                is_custom = True
            case _:
                raise_exhaustive_enum_error(provider.name)

        if is_custom:
            if self._api_base is None:
                raise ValueError(
                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
                )
            # Use openai as it's only used for the request format, not the URL
            litellm_provider_name = "openai"

        # Shouldn't be possible, but keep the type checker happy
        if litellm_provider_name is None:
            raise ValueError(
                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
            )

        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
        return self._litellm_model_id
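The result is simply "<litellm provider>/<provider model ID>". A small illustration (the model IDs are examples, not a verified list):

litellm_provider_name = "openrouter"
model_id = "deepseek/deepseek-r1"  # example OpenRouter model ID
assert litellm_provider_name + "/" + model_id == "openrouter/deepseek/deepseek-r1"

# Custom providers (Ollama, openai_compatible, fine-tunes, custom registry) use the
# "openai/" prefix purely for request formatting and rely on api_base for routing,
# e.g. "openai/llama3.1:8b" with api_base="http://localhost:11434/v1".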
async def build_completion_kwargs(self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[dict[str, typing.Any]], top_logprobs: int | None) -> dict[str, typing.Any]:
    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[dict[str, Any]],
        top_logprobs: int | None,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            **extra_body,
            **self._additional_body_options,
        }

        # Response format: json_schema, json_instructions, json_mode, function_calling, etc.
        response_format_options = await self.response_format_options()
        completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs
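Putting the pieces together, the kwargs passed to litellm.acompletion for an OpenAI call in json_schema mode with logprobs requested would look roughly like this (illustrative values; passing api_key via additional_body_options is an assumption about how the config is typically populated):

completion_kwargs = {
    "model": "openai/gpt-4o-mini",
    "messages": [
        {"role": "system", "content": "<task system prompt>"},
        {"role": "user", "content": "<task input>"},
    ],
    "api_base": None,
    "headers": None,
    "api_key": "sk-...",  # from additional_body_options (assumed)
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": {"type": "object"}},
    },
    "logprobs": True,     # only when top_logprobs was requested
    "top_logprobs": 5,
}
# response = await litellm.acompletion(**completion_kwargs)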