kiln_ai.adapters.model_adapters.litellm_adapter

  1from typing import Any, Dict
  2
  3import litellm
  4from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
  5
  6import kiln_ai.datamodel as datamodel
  7from kiln_ai.adapters.ml_model_list import (
  8    KilnModelProvider,
  9    ModelProviderName,
 10    StructuredOutputMode,
 11)
 12from kiln_ai.adapters.model_adapters.base_adapter import (
 13    COT_FINAL_ANSWER_PROMPT,
 14    AdapterConfig,
 15    BaseAdapter,
 16    RunOutput,
 17)
 18from kiln_ai.adapters.model_adapters.litellm_config import (
 19    LiteLlmConfig,
 20)
 21from kiln_ai.datamodel import PromptGenerators, PromptId
 22from kiln_ai.datamodel.task import RunConfig
 23from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 24
 25
 26class LiteLlmAdapter(BaseAdapter):
 27    def __init__(
 28        self,
 29        config: LiteLlmConfig,
 30        kiln_task: datamodel.Task,
 31        prompt_id: PromptId | None = None,
 32        base_adapter_config: AdapterConfig | None = None,
 33    ):
 34        self.config = config
 35        self._additional_body_options = config.additional_body_options
 36        self._api_base = config.base_url
 37        self._headers = config.default_headers
 38        self._litellm_model_id: str | None = None
 39
 40        run_config = RunConfig(
 41            task=kiln_task,
 42            model_name=config.model_name,
 43            model_provider_name=config.provider_name,
 44            prompt_id=prompt_id or PromptGenerators.SIMPLE,
 45        )
 46
 47        super().__init__(
 48            run_config=run_config,
 49            config=base_adapter_config,
 50        )
 51
 52    async def _run(self, input: Dict | str) -> RunOutput:
 53        provider = self.model_provider()
 54        if not provider.model_id:
 55            raise ValueError("Model ID is required for OpenAI compatible models")
 56
 57        intermediate_outputs: dict[str, str] = {}
 58        prompt = self.build_prompt()
 59        user_msg = self.prompt_builder.build_user_message(input)
 60        messages = [
 61            {"role": "system", "content": prompt},
 62            {"role": "user", "content": user_msg},
 63        ]
 64
 65        run_strategy, cot_prompt = self.run_strategy()
 66
 67        if run_strategy == "cot_as_message":
 68            if not cot_prompt:
 69                raise ValueError("cot_prompt is required for cot_as_message strategy")
 70            messages.append({"role": "system", "content": cot_prompt})
 71        elif run_strategy == "cot_two_call":
 72            if not cot_prompt:
 73                raise ValueError("cot_prompt is required for cot_two_call strategy")
 74            messages.append({"role": "system", "content": cot_prompt})
 75
 76            # First call for chain of thought - No logprobs as only needed for final answer
 77            completion_kwargs = await self.build_completion_kwargs(
 78                provider, messages, None
 79            )
 80            cot_response = await litellm.acompletion(**completion_kwargs)
 81            if (
 82                not isinstance(cot_response, ModelResponse)
 83                or not cot_response.choices
 84                or len(cot_response.choices) == 0
 85                or not isinstance(cot_response.choices[0], Choices)
 86            ):
 87                raise RuntimeError(
 88                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
 89                )
 90            cot_content = cot_response.choices[0].message.content
 91            if cot_content is not None:
 92                intermediate_outputs["chain_of_thought"] = cot_content
 93
 94            messages.extend(
 95                [
 96                    {"role": "assistant", "content": cot_content or ""},
 97                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
 98                ]
 99            )
100
101        # Make the API call using litellm
102        completion_kwargs = await self.build_completion_kwargs(
103            provider, messages, self.base_adapter_config.top_logprobs
104        )
105        response = await litellm.acompletion(**completion_kwargs)
106
107        if not isinstance(response, ModelResponse):
108            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")
109
110        # Maybe remove this? There is no error attribute on the typed response object.
111        # Keeping the check typesafe since it was added for a reason; should investigate what that was and whether it still applies.
112        if hasattr(response, "error") and response.__getattribute__("error"):
113            raise RuntimeError(
114                f"LLM API returned an error: {response.__getattribute__('error')}"
115            )
116
117        if (
118            not response.choices
119            or len(response.choices) == 0
120            or not isinstance(response.choices[0], Choices)
121        ):
122            raise RuntimeError(
123                "No message content returned in the response from LLM API"
124            )
125
126        message = response.choices[0].message
127        logprobs = (
128            response.choices[0].logprobs
129            if hasattr(response.choices[0], "logprobs")
130            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
131            else None
132        )
133
134        # Check logprobs worked, if requested
135        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
136            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
137
138        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
139        if hasattr(message, "reasoning_content") and message.reasoning_content:
140            intermediate_outputs["reasoning"] = message.reasoning_content
141
142        # the string content of the response
143        response_content = message.content
144
145        # Fallback: use the arguments of the first tool call to task_response, if present
146        if (
147            not response_content
148            and hasattr(message, "tool_calls")
149            and message.tool_calls
150        ):
151            tool_call = next(
152                (
153                    tool_call
154                    for tool_call in message.tool_calls
155                    if tool_call.function.name == "task_response"
156                ),
157                None,
158            )
159            if tool_call:
160                response_content = tool_call.function.arguments
161
162        if not isinstance(response_content, str):
163            raise RuntimeError(f"response is not a string: {response_content}")
164
165        return RunOutput(
166            output=response_content,
167            intermediate_outputs=intermediate_outputs,
168            output_logprobs=logprobs,
169        )
170
171    def adapter_name(self) -> str:
172        return "kiln_openai_compatible_adapter"
173
174    async def response_format_options(self) -> dict[str, Any]:
175        # Unstructured if task isn't structured
176        if not self.has_structured_output():
177            return {}
178
179        provider = self.model_provider()
180        match provider.structured_output_mode:
181            case StructuredOutputMode.json_mode:
182                return {"response_format": {"type": "json_object"}}
183            case StructuredOutputMode.json_schema:
184                return self.json_schema_response_format()
185            case StructuredOutputMode.function_calling_weak:
186                return self.tool_call_params(strict=False)
187            case StructuredOutputMode.function_calling:
188                return self.tool_call_params(strict=True)
189            case StructuredOutputMode.json_instructions:
190                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
191                return {}
192            case StructuredOutputMode.json_custom_instructions:
193                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
194                return {}
195            case StructuredOutputMode.json_instruction_and_object:
196                # We set response_format to json_object and also set json instructions in the prompt
197                return {"response_format": {"type": "json_object"}}
198            case StructuredOutputMode.default:
199                if provider.name == ModelProviderName.ollama:
200                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
201                    return self.json_schema_response_format()
202                else:
203                    # Default to function calling -- it's older than the other modes. Higher compatibility.
204                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
205                    strict = provider.name == ModelProviderName.openai
206                    return self.tool_call_params(strict=strict)
207            case _:
208                raise_exhaustive_enum_error(provider.structured_output_mode)
209
210    def json_schema_response_format(self) -> dict[str, Any]:
211        output_schema = self.task().output_schema()
212        return {
213            "response_format": {
214                "type": "json_schema",
215                "json_schema": {
216                    "name": "task_response",
217                    "schema": output_schema,
218                },
219            }
220        }
221
222    def tool_call_params(self, strict: bool) -> dict[str, Any]:
223        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
224        output_schema = self.task().output_schema()
225        if not isinstance(output_schema, dict):
226            raise ValueError(
227                "Invalid output schema for this task. Cannot use tool calls."
228            )
229        output_schema["additionalProperties"] = False
230
231        function_params = {
232            "name": "task_response",
233            "parameters": output_schema,
234        }
235        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
236        if strict:
237            function_params["strict"] = True
238
239        return {
240            "tools": [
241                {
242                    "type": "function",
243                    "function": function_params,
244                }
245            ],
246            "tool_choice": {
247                "type": "function",
248                "function": {"name": "task_response"},
249            },
250        }
251
252    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
253        # TODO P1: Don't love having this logic here. But it's a usability improvement
254        # so better to keep it than exclude it. Should figure out how I want to isolate
255        # this sort of logic so it's config driven and can be overridden
256
257        extra_body = {}
258        provider_options = {}
259
260        if provider.thinking_level is not None:
261            extra_body["reasoning_effort"] = provider.thinking_level
262
263        if provider.require_openrouter_reasoning:
264            # https://openrouter.ai/docs/use-cases/reasoning-tokens
265            extra_body["reasoning"] = {
266                "exclude": False,
267            }
268
269        if provider.anthropic_extended_thinking:
270            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
271
272        if provider.r1_openrouter_options:
273            # Require providers that support the reasoning parameter
274            provider_options["require_parameters"] = True
275            # Prefer R1 providers with reasonable perf/quants
276            provider_options["order"] = ["Fireworks", "Together"]
277            # R1 providers with unreasonable quants
278            provider_options["ignore"] = ["DeepInfra"]
279
280        # Only set these options if this request needs logprobs.
281        if (
282            provider.logprobs_openrouter_options
283            and self.base_adapter_config.top_logprobs is not None
284        ):
285            # Don't let OpenRouter choose a provider that doesn't support logprobs.
286            provider_options["require_parameters"] = True
287            # DeepInfra silently fails to return logprobs consistently.
288            provider_options["ignore"] = ["DeepInfra"]
289
290        if provider.openrouter_skip_required_parameters:
291            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
292            provider_options["require_parameters"] = False
293
294        if len(provider_options) > 0:
295            extra_body["provider"] = provider_options
296
297        return extra_body
298
299    def litellm_model_id(self) -> str:
300        # The model ID is an interesting combination of format and URL endpoint.
301        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
302
303        if self._litellm_model_id:
304            return self._litellm_model_id
305
306        provider = self.model_provider()
307        if not provider.model_id:
308            raise ValueError("Model ID is required for OpenAI compatible models")
309
310        litellm_provider_name: str | None = None
311        is_custom = False
312        match provider.name:
313            case ModelProviderName.openrouter:
314                litellm_provider_name = "openrouter"
315            case ModelProviderName.openai:
316                litellm_provider_name = "openai"
317            case ModelProviderName.groq:
318                litellm_provider_name = "groq"
319            case ModelProviderName.anthropic:
320                litellm_provider_name = "anthropic"
321            case ModelProviderName.ollama:
322                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
323                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
324                is_custom = True
325            case ModelProviderName.gemini_api:
326                litellm_provider_name = "gemini"
327            case ModelProviderName.fireworks_ai:
328                litellm_provider_name = "fireworks_ai"
329            case ModelProviderName.amazon_bedrock:
330                litellm_provider_name = "bedrock"
331            case ModelProviderName.azure_openai:
332                litellm_provider_name = "azure"
333            case ModelProviderName.huggingface:
334                litellm_provider_name = "huggingface"
335            case ModelProviderName.vertex:
336                litellm_provider_name = "vertex_ai"
337            case ModelProviderName.together_ai:
338                litellm_provider_name = "together_ai"
339            case ModelProviderName.openai_compatible:
340                is_custom = True
341            case ModelProviderName.kiln_custom_registry:
342                is_custom = True
343            case ModelProviderName.kiln_fine_tune:
344                is_custom = True
345            case _:
346                raise_exhaustive_enum_error(provider.name)
347
348        if is_custom:
349            if self._api_base is None:
350                raise ValueError(
351                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
352                )
353            # Use openai as it's only used for the request format, not the URL
354            litellm_provider_name = "openai"
355
356        # Shouldn't be possible, but keeps the type checker happy
357        if litellm_provider_name is None:
358            raise ValueError(
359                f"Could not map provider name to a valid litellm provider ID for model {provider.model_id}"
360            )
361
362        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
363        return self._litellm_model_id
364
365    async def build_completion_kwargs(
366        self,
367        provider: KilnModelProvider,
368        messages: list[dict[str, Any]],
369        top_logprobs: int | None,
370    ) -> dict[str, Any]:
371        extra_body = self.build_extra_body(provider)
372
373        # Merge all parameters into a single kwargs dict for litellm
374        completion_kwargs = {
375            "model": self.litellm_model_id(),
376            "messages": messages,
377            "api_base": self._api_base,
378            "headers": self._headers,
379            **extra_body,
380            **self._additional_body_options,
381        }
382
383        # Response format: json_schema, json_instructions, json_mode, function_calling, etc
384        response_format_options = await self.response_format_options()
385        completion_kwargs.update(response_format_options)
386
387        if top_logprobs is not None:
388            completion_kwargs["logprobs"] = True
389            completion_kwargs["top_logprobs"] = top_logprobs
390
391        return completion_kwargs

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs

LiteLlmAdapter( config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig, kiln_task: kiln_ai.datamodel.Task, prompt_id: Optional[Annotated[str, AfterValidator(func=<function <lambda>>)]] = None, base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None)
28    def __init__(
29        self,
30        config: LiteLlmConfig,
31        kiln_task: datamodel.Task,
32        prompt_id: PromptId | None = None,
33        base_adapter_config: AdapterConfig | None = None,
34    ):
35        self.config = config
36        self._additional_body_options = config.additional_body_options
37        self._api_base = config.base_url
38        self._headers = config.default_headers
39        self._litellm_model_id: str | None = None
40
41        run_config = RunConfig(
42            task=kiln_task,
43            model_name=config.model_name,
44            model_provider_name=config.provider_name,
45            prompt_id=prompt_id or PromptGenerators.SIMPLE,
46        )
47
48        super().__init__(
49            run_config=run_config,
50            config=base_adapter_config,
51        )
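
A minimal construction sketch, assuming an existing kiln_ai.datamodel.Task object named `task`. The model name, API key placement, and the explicit None defaults below are illustrative assumptions, not values taken from this module:

from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig

# `task` is assumed to be an existing kiln_ai.datamodel.Task
config = LiteLlmConfig(
    model_name="gpt_4o_mini",                        # hypothetical Kiln model name
    provider_name=ModelProviderName.openai,          # assumes the enum (or its string value) is accepted here
    base_url=None,                                   # only required for custom/Ollama-style providers
    default_headers=None,
    additional_body_options={"api_key": "sk-..."},   # assumed to be forwarded verbatim to litellm
)
adapter = LiteLlmAdapter(config=config, kiln_task=task)
print(adapter.adapter_name())  # "kiln_openai_compatible_adapter"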
config
def adapter_name(self) -> str:
172    def adapter_name(self) -> str:
173        return "kiln_openai_compatible_adapter"
async def response_format_options(self) -> dict[str, typing.Any]:
175    async def response_format_options(self) -> dict[str, Any]:
176        # Unstructured if task isn't structured
177        if not self.has_structured_output():
178            return {}
179
180        provider = self.model_provider()
181        match provider.structured_output_mode:
182            case StructuredOutputMode.json_mode:
183                return {"response_format": {"type": "json_object"}}
184            case StructuredOutputMode.json_schema:
185                return self.json_schema_response_format()
186            case StructuredOutputMode.function_calling_weak:
187                return self.tool_call_params(strict=False)
188            case StructuredOutputMode.function_calling:
189                return self.tool_call_params(strict=True)
190            case StructuredOutputMode.json_instructions:
191                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
192                return {}
193            case StructuredOutputMode.json_custom_instructions:
194                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
195                return {}
196            case StructuredOutputMode.json_instruction_and_object:
197                # We set response_format to json_object and also set json instructions in the prompt
198                return {"response_format": {"type": "json_object"}}
199            case StructuredOutputMode.default:
200                if provider.name == ModelProviderName.ollama:
201                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
202                    return self.json_schema_response_format()
203                else:
204                    # Default to function calling -- it's older than the other modes. Higher compatibility.
205                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
206                    strict = provider.name == ModelProviderName.openai
207                    return self.tool_call_params(strict=strict)
208            case _:
209                raise_exhaustive_enum_error(provider.structured_output_mode)
def json_schema_response_format(self) -> dict[str, typing.Any]:
211    def json_schema_response_format(self) -> dict[str, Any]:
212        output_schema = self.task().output_schema()
213        return {
214            "response_format": {
215                "type": "json_schema",
216                "json_schema": {
217                    "name": "task_response",
218                    "schema": output_schema,
219                },
220            }
221        }
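
For a hypothetical task whose output_schema() returns a simple object schema, the returned options would look like this (a sketch, not output from a real task):

schema = {"type": "object", "properties": {"answer": {"type": "string"}}}

options = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": schema},
    }
}
# build_completion_kwargs() merges these keys into the litellm.acompletion kwargs with .update().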
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
223    def tool_call_params(self, strict: bool) -> dict[str, Any]:
224        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
225        output_schema = self.task().output_schema()
226        if not isinstance(output_schema, dict):
227            raise ValueError(
228                "Invalid output schema for this task. Cannot use tool calls."
229            )
230        output_schema["additionalProperties"] = False
231
232        function_params = {
233            "name": "task_response",
234            "parameters": output_schema,
235        }
236        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
237        if strict:
238            function_params["strict"] = True
239
240        return {
241            "tools": [
242                {
243                    "type": "function",
244                    "function": function_params,
245                }
246            ],
247            "tool_choice": {
248                "type": "function",
249                "function": {"name": "task_response"},
250            },
251        }
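
A sketch of what tool_call_params(strict=True) produces for the same hypothetical schema, plus how _run() recovers structured output when the model answers through the forced tool call:

tool_options = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "additionalProperties": False,  # added by tool_call_params
                },
                "strict": True,
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}

# In _run(), if message.content is empty, the JSON string in
# tool_call.function.arguments of the first task_response call is used as the output.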
def build_extra_body( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
253    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
254        # TODO P1: Don't love having this logic here. But it's a usability improvement
255        # so better to keep it than exclude it. Should figure out how I want to isolate
256        # this sort of logic so it's config driven and can be overridden
257
258        extra_body = {}
259        provider_options = {}
260
261        if provider.thinking_level is not None:
262            extra_body["reasoning_effort"] = provider.thinking_level
263
264        if provider.require_openrouter_reasoning:
265            # https://openrouter.ai/docs/use-cases/reasoning-tokens
266            extra_body["reasoning"] = {
267                "exclude": False,
268            }
269
270        if provider.anthropic_extended_thinking:
271            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
272
273        if provider.r1_openrouter_options:
274            # Require providers that support the reasoning parameter
275            provider_options["require_parameters"] = True
276            # Prefer R1 providers with reasonable perf/quants
277            provider_options["order"] = ["Fireworks", "Together"]
278            # R1 providers with unreasonable quants
279            provider_options["ignore"] = ["DeepInfra"]
280
281        # Only set these options if this request needs logprobs.
282        if (
283            provider.logprobs_openrouter_options
284            and self.base_adapter_config.top_logprobs is not None
285        ):
286            # Don't let OpenRouter choose a provider that doesn't support logprobs.
287            provider_options["require_parameters"] = True
288            # DeepInfra silently fails to return logprobs consistently.
289            provider_options["ignore"] = ["DeepInfra"]
290
291        if provider.openrouter_skip_required_parameters:
292            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
293            provider_options["require_parameters"] = False
294
295        if len(provider_options) > 0:
296            extra_body["provider"] = provider_options
297
298        return extra_body
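
For example, a hypothetical OpenRouter provider entry with require_openrouter_reasoning and r1_openrouter_options enabled would produce roughly:

extra_body = {
    "reasoning": {"exclude": False},
    "provider": {
        "require_parameters": True,
        "order": ["Fireworks", "Together"],
        "ignore": ["DeepInfra"],
    },
}
# These keys are spread into the litellm.acompletion kwargs by build_completion_kwargs().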
def litellm_model_id(self) -> str:
300    def litellm_model_id(self) -> str:
301        # The model ID is an interesting combination of format and URL endpoint.
302        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
303
304        if self._litellm_model_id:
305            return self._litellm_model_id
306
307        provider = self.model_provider()
308        if not provider.model_id:
309            raise ValueError("Model ID is required for OpenAI compatible models")
310
311        litellm_provider_name: str | None = None
312        is_custom = False
313        match provider.name:
314            case ModelProviderName.openrouter:
315                litellm_provider_name = "openrouter"
316            case ModelProviderName.openai:
317                litellm_provider_name = "openai"
318            case ModelProviderName.groq:
319                litellm_provider_name = "groq"
320            case ModelProviderName.anthropic:
321                litellm_provider_name = "anthropic"
322            case ModelProviderName.ollama:
323                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
324                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
325                is_custom = True
326            case ModelProviderName.gemini_api:
327                litellm_provider_name = "gemini"
328            case ModelProviderName.fireworks_ai:
329                litellm_provider_name = "fireworks_ai"
330            case ModelProviderName.amazon_bedrock:
331                litellm_provider_name = "bedrock"
332            case ModelProviderName.azure_openai:
333                litellm_provider_name = "azure"
334            case ModelProviderName.huggingface:
335                litellm_provider_name = "huggingface"
336            case ModelProviderName.vertex:
337                litellm_provider_name = "vertex_ai"
338            case ModelProviderName.together_ai:
339                litellm_provider_name = "together_ai"
340            case ModelProviderName.openai_compatible:
341                is_custom = True
342            case ModelProviderName.kiln_custom_registry:
343                is_custom = True
344            case ModelProviderName.kiln_fine_tune:
345                is_custom = True
346            case _:
347                raise_exhaustive_enum_error(provider.name)
348
349        if is_custom:
350            if self._api_base is None:
351                raise ValueError(
352                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
353                )
354            # Use openai as it's only used for the request format, not the URL
355            litellm_provider_name = "openai"
356
357        # Shouldn't be possible, but keeps the type checker happy
358        if litellm_provider_name is None:
359            raise ValueError(
360                f"Could not map provider name to a valid litellm provider ID for model {provider.model_id}"
361            )
362
363        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
364        return self._litellm_model_id
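
The returned ID is always "<litellm_provider>/<provider.model_id>". A few illustrative mappings (the model IDs are assumptions, not entries from ml_model_list):

# ModelProviderName.openrouter, model_id "deepseek/deepseek-r1"     -> "openrouter/deepseek/deepseek-r1"
# ModelProviderName.gemini_api, model_id "gemini-2.0-flash"         -> "gemini/gemini-2.0-flash"
# ModelProviderName.ollama, model_id "llama3.1" (base_url required) -> "openai/llama3.1"
#   ("openai/" only selects the request format; the host comes from api_base in build_completion_kwargs)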
async def build_completion_kwargs( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[dict[str, typing.Any]], top_logprobs: int | None) -> dict[str, typing.Any]:
366    async def build_completion_kwargs(
367        self,
368        provider: KilnModelProvider,
369        messages: list[dict[str, Any]],
370        top_logprobs: int | None,
371    ) -> dict[str, Any]:
372        extra_body = self.build_extra_body(provider)
373
374        # Merge all parameters into a single kwargs dict for litellm
375        completion_kwargs = {
376            "model": self.litellm_model_id(),
377            "messages": messages,
378            "api_base": self._api_base,
379            "headers": self._headers,
380            **extra_body,
381            **self._additional_body_options,
382        }
383
384        # Response format: json_schema, json_instructions, json_mode, function_calling, etc
385        response_format_options = await self.response_format_options()
386        completion_kwargs.update(response_format_options)
387
388        if top_logprobs is not None:
389            completion_kwargs["logprobs"] = True
390            completion_kwargs["top_logprobs"] = top_logprobs
391
392        return completion_kwargs
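
A sketch of how these kwargs are consumed, mirroring the calls in _run() above; the run_once helper and the messages are illustrative, and `adapter` is the instance from the construction sketch earlier:

import asyncio

import litellm

from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter

async def run_once(adapter: LiteLlmAdapter, messages: list[dict]):
    # Build kwargs without logprobs, as the chain-of-thought call in _run() does.
    kwargs = await adapter.build_completion_kwargs(adapter.model_provider(), messages, None)
    return await litellm.acompletion(**kwargs)

# response = asyncio.run(run_once(adapter, [{"role": "user", "content": "Hello"}]))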