kiln_ai.adapters.langchain_adapters

import os
from os import getenv
from typing import Any, Dict

from langchain_aws import ChatBedrockConverse
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ollama_tools import (
    get_ollama_connection,
    ollama_base_url,
    ollama_model_installed,
)
from kiln_ai.utils.config import Config

from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import KilnModelProvider, ModelProviderName
from .provider_tools import kiln_model_provider_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


class LangchainAdapter(BaseAdapter):
    _model: LangChainModelType | None = None

    def __init__(
        self,
        kiln_task: datamodel.Task,
        custom_model: BaseChatModel | None = None,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
        tags: list[str] | None = None,
    ):
        super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags)
        if custom_model is not None:
            self._model = custom_model

            # Attempt to infer model provider and name from the custom model
            self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
            self.model_name = "custom.langchain:unknown_model"
            if hasattr(custom_model, "model_name") and isinstance(
                getattr(custom_model, "model_name"), str
            ):
                self.model_name = "custom.langchain:" + getattr(
                    custom_model, "model_name"
                )
            if hasattr(custom_model, "model") and isinstance(
                getattr(custom_model, "model"), str
            ):
                self.model_name = "custom.langchain:" + getattr(custom_model, "model")
        elif model_name is not None:
            self.model_name = model_name
            self.model_provider = provider or "custom.langchain.default_provider"
        else:
            raise ValueError(
                "model_name must be provided if custom_model is not provided"
            )

    async def model(self) -> LangChainModelType:
        # Return the cached model if we already built one
        if self._model:
            return self._model

        self._model = await langchain_model_from(self.model_name, self.model_provider)

        if self.has_structured_output():
            if not hasattr(self._model, "with_structured_output") or not callable(
                getattr(self._model, "with_structured_output")
            ):
                raise ValueError(
                    f"model {self._model} does not support structured output, cannot use output_json_schema"
                )
            # LangChain expects title/description at the top level of the JSON schema
            output_schema = self.kiln_task.output_schema()
            if output_schema is None:
                raise ValueError(
                    f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}"
                )
            output_schema["title"] = "task_response"
            output_schema["description"] = "A response from the task"
            with_structured_output_options = await get_structured_output_options(
                self.model_name, self.model_provider
            )
            self._model = self._model.with_structured_output(
                output_schema,
                include_raw=True,
                **with_structured_output_options,
            )
        return self._model

    async def _run(self, input: Dict | str) -> RunOutput:
        model = await self.model()
        chain = model
        intermediate_outputs = {}

        prompt = self.build_prompt()
        user_msg = self.prompt_builder.build_user_message(input)
        messages = [
            SystemMessage(content=prompt),
            HumanMessage(content=user_msg),
        ]

        # COT with structured output
        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
        if cot_prompt and self.has_structured_output():
            # Base model (without structured output) used for COT message
            base_model = await langchain_model_from(
                self.model_name, self.model_provider
            )
            messages.append(
                SystemMessage(content=cot_prompt),
            )

            cot_messages = [*messages]
            cot_response = await base_model.ainvoke(cot_messages)
            intermediate_outputs["chain_of_thought"] = cot_response.content
            messages.append(AIMessage(content=cot_response.content))
            messages.append(
                SystemMessage(content="Considering the above, return a final result.")
            )
        elif cot_prompt:
            messages.append(SystemMessage(content=cot_prompt))

        response = await chain.ainvoke(messages)

        if self.has_structured_output():
            if (
                not isinstance(response, dict)
                or "parsed" not in response
                or not isinstance(response["parsed"], dict)
            ):
                raise RuntimeError(f"structured response not returned: {response}")
            structured_response = response["parsed"]
            return RunOutput(
                output=self._munge_response(structured_response),
                intermediate_outputs=intermediate_outputs,
            )
        else:
            if not isinstance(response, BaseMessage):
                raise RuntimeError(f"response is not a BaseMessage: {response}")
            text_content = response.content
            if not isinstance(text_content, str):
                raise RuntimeError(f"response is not a string: {text_content}")
            return RunOutput(
                output=text_content,
                intermediate_outputs=intermediate_outputs,
            )

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            model_name=self.model_name,
            model_provider=self.model_provider,
            adapter_name="kiln_langchain_adapter",
            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
        )

    def _munge_response(self, response: Dict) -> Dict:
        # Mistral Large tool calling format is a bit different. Convert to standard format.
        if (
            "name" in response
            and response["name"] == "task_response"
            and "arguments" in response
        ):
            return response["arguments"]
        return response


async def get_structured_output_options(
    model_name: str, model_provider: str
) -> Dict[str, Any]:
    finetune_provider = await kiln_model_provider_from(model_name, model_provider)
    if finetune_provider and finetune_provider.adapter_options.get("langchain"):
        return finetune_provider.adapter_options["langchain"].get(
            "with_structured_output_options", {}
        )
    return {}


async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> BaseChatModel:
    provider = await kiln_model_provider_from(name, provider_name)
    return await langchain_model_from_provider(provider, name)


async def langchain_model_from_provider(
    provider: KilnModelProvider, model_name: str
) -> BaseChatModel:
    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.openai_compatible:
        # See provider_tools.py for how base_url, key and other parameters are set
        return ChatOpenAI(**provider.provider_options)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        if api_key is None or secret_key is None:
            raise ValueError(
                "Attempted to use Amazon Bedrock without an access key and secret key set."
            )
        # langchain doesn't allow passing these directly, so set them via env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.fireworks_ai:
        api_key = Config.shared().fireworks_api_key
        return ChatFireworks(**provider.provider_options, api_key=api_key)
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name.
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama has installed
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        # Use a distinct loop variable so we don't shadow the model_name parameter
        for potential_name in potential_model_names:
            if ollama_model_installed(ollama_connection, potential_name):
                return ChatOllama(model=potential_name, base_url=ollama_base_url())

        raise ValueError(f"Model {model_name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")
LangChainModelType = typing.Union[langchain_core.language_models.chat_models.BaseChatModel, langchain_core.runnables.base.Runnable[typing.Union[langchain_core.prompt_values.PromptValue, str, collections.abc.Sequence[typing.Union[langchain_core.messages.base.BaseMessage, list[str], tuple[str, str], str, dict[str, typing.Any]]]], typing.Union[typing.Dict, pydantic.main.BaseModel]]]
class LangchainAdapter(kiln_ai.adapters.base_adapter.BaseAdapter):

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs
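
A minimal usage sketch (not part of this module's source): it assumes a Kiln task file already exists on disk, that kiln_ai.datamodel.Task exposes a load_from_file loader, and that BaseAdapter provides an async invoke method returning a TaskRun; the model and provider names are only examples.

import asyncio

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.langchain_adapters import LangchainAdapter


async def main() -> None:
    # Hypothetical path and names; adjust to your project and ml_model_list
    task = datamodel.Task.load_from_file("path/to/task.kiln")
    adapter = LangchainAdapter(
        kiln_task=task,
        model_name="llama_3_1_8b",  # example Kiln model name
        provider="ollama",
    )
    run = await adapter.invoke("Summarize: the quick brown fox jumps over the lazy dog.")
    print(run.output.output)  # assumes invoke returns a TaskRun whose output holds the text


asyncio.run(main())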

LangchainAdapter( kiln_task: kiln_ai.datamodel.Task, custom_model: langchain_core.language_models.chat_models.BaseChatModel | None = None, model_name: str | None = None, provider: str | None = None, prompt_builder: kiln_ai.adapters.prompt_builders.BasePromptBuilder | None = None, tags: list[str] | None = None)
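
The custom_model branch of the constructor wraps any LangChain chat model directly; the adapter then infers model_name and model_provider from the instance for adapter_info. A hedged sketch, assuming langchain_ollama is installed and an Ollama server is running locally:

from langchain_ollama import ChatOllama

from kiln_ai.adapters.langchain_adapters import LangchainAdapter

# `task` is an existing kiln_ai.datamodel.Task (see the sketch above)
custom = ChatOllama(model="llama3.1:8b", base_url="http://localhost:11434")
adapter = LangchainAdapter(kiln_task=task, custom_model=custom)
# model_name is inferred from the instance's `model` attribute:
# "custom.langchain:llama3.1:8b"; model_provider becomes "custom.langchain:ChatOllama"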
async def model( self) -> Union[langchain_core.language_models.chat_models.BaseChatModel, langchain_core.runnables.base.Runnable[Union[langchain_core.prompt_values.PromptValue, str, Sequence[Union[langchain_core.messages.base.BaseMessage, list[str], tuple[str, str], str, dict[str, Any]]]], Union[Dict, pydantic.main.BaseModel]]]:
def adapter_info(self) -> kiln_ai.adapters.base_adapter.AdapterInfo:
async def get_structured_output_options(model_name: str, model_provider: str) -> Dict[str, Any]:
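
For reference, this function reads a nested dict keyed by "langchain" and "with_structured_output_options" from the provider's adapter_options. An illustrative (not real) provider entry, assuming KilnModelProvider accepts these fields as keyword arguments; the actual entries live in ml_model_list:

from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName

example_provider = KilnModelProvider(
    name=ModelProviderName.openai,
    provider_options={"model": "gpt-4o-mini"},  # example model id
    adapter_options={
        "langchain": {
            # Forwarded as keyword arguments to model.with_structured_output(...)
            "with_structured_output_options": {"method": "json_mode"},
        }
    },
)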
async def langchain_model_from( name: str, provider_name: str | None = None) -> langchain_core.language_models.chat_models.BaseChatModel:
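
A short sketch of resolving a LangChain chat model by Kiln model and provider name, assuming the names exist in Kiln's model list and the provider (here a local Ollama server) is configured:

import asyncio

from kiln_ai.adapters.langchain_adapters import langchain_model_from


async def main() -> None:
    chat_model = await langchain_model_from("llama_3_1_8b", "ollama")  # example names
    reply = await chat_model.ainvoke("Say hello in one word.")
    print(reply.content)


asyncio.run(main())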
async def langchain_model_from_provider( provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, model_name: str) -> langchain_core.language_models.chat_models.BaseChatModel: