From 5e2cae62418414d5e488bc2ee24174f40b6b92a6 Mon Sep 17 00:00:00 2001
From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
Date: Thu, 2 Apr 2026 12:40:52 +0200
Subject: [PATCH 1/2] feat(llm): add LangChain adapter and framework registry

Add LangChainLLMAdapter wrapping BaseChatModel behind the LLMModel
protocol, LangChainFramework implementing create_model(), and a
pluggable framework registry. Purely additive; no existing files are
modified.

Part of the LangChain decoupling epic.

feat(llm): implement provider_url on LangChainLLMAdapter

Extract the endpoint URL from the underlying LangChain LLM by checking
common base URL attributes and the nested client object.
---
 .../integrations/langchain/llm_adapter.py   | 400 ++++++++++++++++++
 .../integrations/langchain/message_utils.py |  30 ++
 nemoguardrails/llm/frameworks.py            |  55 +++
 tests/llm/test_frameworks.py                |  78 ++++
 tests/test_langchain_llm_adapter.py         | 348 +++++++++++++++
 5 files changed, 911 insertions(+)
 create mode 100644 nemoguardrails/integrations/langchain/llm_adapter.py
 create mode 100644 nemoguardrails/llm/frameworks.py
 create mode 100644 tests/llm/test_frameworks.py
 create mode 100644 tests/test_langchain_llm_adapter.py

diff --git a/nemoguardrails/integrations/langchain/llm_adapter.py b/nemoguardrails/integrations/langchain/llm_adapter.py
new file mode 100644
index 0000000000..a962cd7a08
--- /dev/null
+++ b/nemoguardrails/integrations/langchain/llm_adapter.py
@@ -0,0 +1,400 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import uuid
+from typing import Any, AsyncIterator, Dict, List, Optional, Union
+
+from nemoguardrails.types import (
+    ChatMessage,
+    FinishReason,
+    LLMModel,
+    LLMResponse,
+    LLMResponseChunk,
+    ToolCall,
+    ToolCallFunction,
+    UsageInfo,
+)
+
+
+def _infer_model_name(llm: Any):
+    """Helper to infer the model name from an LLM instance.
+
+    Because not all LangChain models implement _identifying_params correctly, we have to
+    infer the name manually.
+    """
+    for attr in ["model", "model_name"]:
+        if hasattr(llm, attr):
+            val = getattr(llm, attr)
+            if isinstance(val, str):
+                return val
+
+    model_kwargs = getattr(llm, "model_kwargs", None)
+    if model_kwargs and isinstance(model_kwargs, Dict):
+        for attr in ["model", "model_name", "name"]:
+            val = model_kwargs.get(attr)
+            if isinstance(val, str):
+                return val
+
+    # If we still can't figure it out, return "unknown".
+    return "unknown"
+
+
+def _infer_provider_from_module(llm: Any) -> Optional[str]:
+    """Infer the provider name from the LLM's module path.
+
+    This function extracts the provider name from LangChain package naming conventions:
+    - langchain_openai -> openai
+    - langchain_anthropic -> anthropic
+    - langchain_google_genai -> google_genai
+    - langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints
+    - langchain_community.chat_models.ollama -> ollama
+
+    For patched/wrapped classes, checks base classes as well.
+ + Args: + llm: The LLM instance + + Returns: + The inferred provider name, or None if it cannot be determined + """ + module = type(llm).__module__ + + if module.startswith("langchain_"): + package = module.split(".")[0] + provider = package.replace("langchain_", "") + + if provider == "community": + parts = module.split(".") + if len(parts) >= 3: + provider = parts[-1] + return provider + else: + return provider + + for base_class in type(llm).__mro__[1:]: + base_module = base_class.__module__ + if base_module.startswith("langchain_"): + package = base_module.split(".")[0] + provider = package.replace("langchain_", "") + + if provider == "community": + parts = base_module.split(".") + if len(parts) >= 3: + provider = parts[-1] + return provider + else: + return provider + + return None + + +_BASE_URL_ATTRIBUTES = [ + "base_url", + "endpoint_url", + "server_url", + "azure_endpoint", + "openai_api_base", + "api_base", + "api_host", + "endpoint", +] + + +class LangChainLLMAdapter: + def __init__(self, llm): + self._llm = llm + + @property + def raw_llm(self) -> Any: + return self._llm + + @property + def model_name(self) -> str: + return _infer_model_name(self._llm) + + @property + def provider_name(self) -> Optional[str]: + return _infer_provider_from_module(self._llm) + + @property + def provider_url(self) -> Optional[str]: + # temp: uses _BASE_URL_ATTRIBUTES which duplicates utils.py BASE_URL_ATTRIBUTES. + # utils.py copy will be removed in stack-3 when it switches to model.provider_url. + for attr in _BASE_URL_ATTRIBUTES: + value = getattr(self._llm, attr, None) + if value: + return str(value) + client = getattr(self._llm, "client", None) + if client and hasattr(client, "base_url"): + return str(client.base_url) + return None + + def _filter_reasoning_model_params(self, params: Optional[dict]) -> Optional[dict]: + if not params or "temperature" not in params: + return params + + model_name = _infer_model_name(self._llm).lower() + + is_openai_reasoning_model = ( + model_name.startswith("o1") + or model_name.startswith("o3") + or (model_name.startswith("gpt-5") and "chat" not in model_name) + ) + + if is_openai_reasoning_model: + filtered = params.copy() + filtered.pop("temperature", None) + return filtered + + return params + + def _prepare_llm(self, kwargs: dict): + kwargs = self._filter_reasoning_model_params(kwargs) or {} + llm = self._llm + if kwargs: + llm = llm.bind(**kwargs) + return llm + + def _to_langchain_input(self, prompt): + if isinstance(prompt, list): + from nemoguardrails.integrations.langchain.message_utils import ( + chatmessages_to_langchain_messages, + ) + + return chatmessages_to_langchain_messages(prompt) + return prompt + + async def generate( + self, + prompt: Union[str, List[ChatMessage]], + *, + stop: Optional[List[str]] = None, + **kwargs, + ) -> LLMResponse: + llm = self._prepare_llm(kwargs) + messages = self._to_langchain_input(prompt) + response = await llm.ainvoke(messages, stop=stop) + return _langchain_response_to_llm_response(response) + + async def stream( + self, + prompt: Union[str, List[ChatMessage]], + *, + stop: Optional[List[str]] = None, + **kwargs, + ) -> AsyncIterator[LLMResponseChunk]: + llm = self._prepare_llm(kwargs) + messages = self._to_langchain_input(prompt) + async for chunk in llm.astream(messages, stop=stop): + yield _langchain_chunk_to_llm_response_chunk(chunk) + + +class LangChainFramework: + def create_model( + self, + model_name: str, + provider_name: str, + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> LLMModel: + from 
nemoguardrails.llm.models.langchain_initializer import ( + init_langchain_model, + ) + + kwargs = dict(model_kwargs) if model_kwargs else {} + mode = kwargs.pop("mode", "chat") + + raw_llm = init_langchain_model( + model_name=model_name, + provider_name=provider_name, + mode=mode, + kwargs=kwargs, + ) + return LangChainLLMAdapter(raw_llm) + + +_FINISH_REASON_MAP: Dict[str, FinishReason] = { + "stop": "stop", + "end_turn": "stop", + "length": "length", + "max_tokens": "length", + "tool_calls": "tool_calls", + "tool_use": "tool_calls", + "content_filter": "content_filter", +} + + +def _map_finish_reason(raw: Optional[str]) -> Optional[FinishReason]: + if raw is None: + return None + return _FINISH_REASON_MAP.get(raw, "other") + + +def _build_usage_info(raw: Any) -> Optional[UsageInfo]: + if raw is None: + return None + if not isinstance(raw, dict): + try: + raw = dict(raw) + except (TypeError, ValueError): + return None + if not raw: + return None + return UsageInfo( + input_tokens=raw.get("input_tokens", raw.get("prompt_tokens", 0)), + output_tokens=raw.get("output_tokens", raw.get("completion_tokens", 0)), + total_tokens=raw.get("total_tokens", 0), + reasoning_tokens=raw.get("reasoning_tokens"), + cached_tokens=raw.get("cached_tokens", raw.get("cache_read_input_tokens")), + ) + + +def _extract_reasoning(response) -> Optional[str]: + content_blocks = getattr(response, "content_blocks", None) + if content_blocks: + for block in content_blocks: + if isinstance(block, dict) and block.get("type") == "reasoning": + val = block.get("reasoning") + if val: + return val + + additional_kwargs = getattr(response, "additional_kwargs", None) + if additional_kwargs and isinstance(additional_kwargs, dict): + val = additional_kwargs.get("reasoning_content") + if val: + return val + + return None + + +def _langchain_response_to_llm_response(response) -> LLMResponse: + content = getattr(response, "content", None) + if content is None: + content = str(response) + + reasoning = _extract_reasoning(response) + + raw_tool_calls = getattr(response, "tool_calls", None) + tool_calls = None + if raw_tool_calls: + tool_calls = [] + for tc in raw_tool_calls: + if isinstance(tc, dict): + tool_calls.append( + ToolCall( + id=tc.get("id") or str(uuid.uuid4()), + type="function", + function=ToolCallFunction( + name=tc.get("name", ""), + arguments=tc.get("args", {}), + ), + ) + ) + else: + tool_calls.append( + ToolCall( + id=getattr(tc, "id", None) or str(uuid.uuid4()), + type="function", + function=ToolCallFunction( + name=getattr(tc, "name", ""), + arguments=getattr(tc, "args", {}), + ), + ) + ) + + response_metadata = getattr(response, "response_metadata", None) or {} + additional_kwargs = getattr(response, "additional_kwargs", None) or {} + + usage_metadata = getattr(response, "usage_metadata", None) + usage = _build_usage_info(usage_metadata) + if usage is None and response_metadata: + token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") + if token_usage: + usage = _build_usage_info(token_usage) + + model = response_metadata.get("model_name") or response_metadata.get("model") + + raw_finish = response_metadata.get("finish_reason") or response_metadata.get("stop_reason") + finish_reason = _map_finish_reason(raw_finish) + + stop_sequence = response_metadata.get("stop_sequence") + + request_id = response_metadata.get("id") or response_metadata.get("request_id") + + extracted_keys = { + "model_name", + "model", + "finish_reason", + "stop_reason", + "stop_sequence", + "id", + "request_id", + 
"token_usage", + "usage", + } + reasoning_keys = {"reasoning_content"} + provider_metadata: Dict[str, Any] = {} + for k, v in response_metadata.items(): + if k not in extracted_keys: + provider_metadata[k] = v + for k, v in additional_kwargs.items(): + if k not in reasoning_keys and k not in provider_metadata: + provider_metadata[k] = v + + return LLMResponse( + content=content, + reasoning=reasoning, + tool_calls=tool_calls, + model=model, + finish_reason=finish_reason, + stop_sequence=stop_sequence, + request_id=request_id, + usage=usage, + provider_metadata=provider_metadata if provider_metadata else None, + ) + + +def _langchain_chunk_to_llm_response_chunk(chunk) -> LLMResponseChunk: + content = getattr(chunk, "content", None) + if content is None: + content = getattr(chunk, "text", None) + if content is None: + content = str(chunk) + + response_metadata = getattr(chunk, "response_metadata", None) or {} + generation_info = getattr(chunk, "generation_info", None) or {} + + usage_metadata = getattr(chunk, "usage_metadata", None) + usage = _build_usage_info(usage_metadata) + if usage is None and response_metadata: + token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") + if token_usage: + usage = _build_usage_info(token_usage) + if usage is None and generation_info: + token_usage = generation_info.get("token_usage") or generation_info.get("usage") + if token_usage: + usage = _build_usage_info(token_usage) + + provider_metadata: Dict[str, Any] = {} + for k, v in response_metadata.items(): + provider_metadata[k] = v + for k, v in generation_info.items(): + if k not in provider_metadata: + provider_metadata[k] = v + + return LLMResponseChunk( + delta_content=content, + usage=usage, + provider_metadata=provider_metadata if provider_metadata else None, + ) diff --git a/nemoguardrails/integrations/langchain/message_utils.py b/nemoguardrails/integrations/langchain/message_utils.py index 5ccb3d8803..cec37b0485 100644 --- a/nemoguardrails/integrations/langchain/message_utils.py +++ b/nemoguardrails/integrations/langchain/message_utils.py @@ -147,6 +147,36 @@ def is_base_message(obj: Any) -> bool: return isinstance(obj, BaseMessage) +def chatmessage_to_langchain_message(msg: "ChatMessage") -> BaseMessage: + from nemoguardrails.types import Role + + content = msg.content or "" + if msg.role == Role.USER: + return HumanMessage(content=content) + elif msg.role == Role.SYSTEM: + return SystemMessage(content=content) + elif msg.role == Role.TOOL: + return ToolMessage(content=content, tool_call_id=msg.tool_call_id or "") + elif msg.role == Role.ASSISTANT: + kwargs: Dict[str, Any] = {} + if msg.tool_calls: + kwargs["tool_calls"] = [ + { + "name": tc.function.name, + "args": tc.function.arguments, + "id": tc.id, + "type": "tool_call", + } + for tc in msg.tool_calls + ] + return AIMessage(content=content, **kwargs) + return HumanMessage(content=content) + + +def chatmessages_to_langchain_messages(msgs: List["ChatMessage"]) -> List[BaseMessage]: + return [chatmessage_to_langchain_message(m) for m in msgs] + + def is_ai_message(obj: Any) -> bool: """Check if an object is an AIMessage.""" return isinstance(obj, AIMessage) diff --git a/nemoguardrails/llm/frameworks.py b/nemoguardrails/llm/frameworks.py new file mode 100644 index 0000000000..143dd034c0 --- /dev/null +++ b/nemoguardrails/llm/frameworks.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Dict + +from nemoguardrails.types import LLMFramework + +_frameworks: Dict[str, LLMFramework] = {} +_default_framework: str = os.environ.get("NEMOGUARDRAILS_LLM_FRAMEWORK", "langchain") + + +def register_framework(name: str, framework: LLMFramework) -> None: + if name in _frameworks: + raise ValueError(f"Framework '{name}' is already registered.") + _frameworks[name] = framework + + +def get_framework(name: str) -> LLMFramework: + if name not in _frameworks: + if name == "langchain": + from nemoguardrails.integrations.langchain.llm_adapter import LangChainFramework + + _frameworks["langchain"] = LangChainFramework() + else: + available = list(_frameworks.keys()) + raise KeyError(f"Unknown framework '{name}'. Available frameworks: {available}") + return _frameworks[name] + + +def set_default_framework(name: str) -> None: + global _default_framework + _default_framework = name + + +def get_default_framework() -> str: + return _default_framework + + +def _reset_frameworks() -> None: + global _default_framework + _frameworks.clear() + _default_framework = os.environ.get("NEMOGUARDRAILS_LLM_FRAMEWORK", "langchain") diff --git a/tests/llm/test_frameworks.py b/tests/llm/test_frameworks.py new file mode 100644 index 0000000000..fe33db91a9 --- /dev/null +++ b/tests/llm/test_frameworks.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import MagicMock + +import pytest + +from nemoguardrails.llm.frameworks import ( + _reset_frameworks, + get_default_framework, + get_framework, + register_framework, + set_default_framework, +) +from nemoguardrails.types import LLMModel + + +@pytest.fixture(autouse=True) +def clean_registry(): + yield + _reset_frameworks() + + +class FakeFramework: + def create_model(self, model_name, provider_name, model_kwargs=None): + return MagicMock(spec=LLMModel) + + +class TestRegistry: + def test_register_and_get_framework(self): + fw = FakeFramework() + register_framework("fake", fw) + assert get_framework("fake") is fw + + def test_register_duplicate_raises_valueerror(self): + register_framework("dup", FakeFramework()) + with pytest.raises(ValueError, match="already registered"): + register_framework("dup", FakeFramework()) + + def test_get_unregistered_raises_keyerror(self): + with pytest.raises(KeyError, match="Unknown framework"): + get_framework("nonexistent") + + def test_langchain_lazy_auto_registration(self): + fw = get_framework("langchain") + from nemoguardrails.integrations.langchain.llm_adapter import LangChainFramework + + assert isinstance(fw, LangChainFramework) + + def test_set_and_get_default_framework(self): + set_default_framework("custom") + assert get_default_framework() == "custom" + + def test_default_is_langchain(self): + assert get_default_framework() == "langchain" + + def test_default_from_env_var(self, monkeypatch): + monkeypatch.setenv("NEMOGUARDRAILS_LLM_FRAMEWORK", "litellm") + _reset_frameworks() + assert get_default_framework() == "litellm" + + def test_reset_clears_registry(self): + register_framework("temp", FakeFramework()) + _reset_frameworks() + with pytest.raises(KeyError): + get_framework("temp") diff --git a/tests/test_langchain_llm_adapter.py b/tests/test_langchain_llm_adapter.py new file mode 100644 index 0000000000..de10b3fd4f --- /dev/null +++ b/tests/test_langchain_llm_adapter.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langchain_core.messages import AIMessage + +from nemoguardrails.integrations.langchain.llm_adapter import ( + LangChainFramework, + LangChainLLMAdapter, + _langchain_chunk_to_llm_response_chunk, + _langchain_response_to_llm_response, +) +from nemoguardrails.types import ChatMessage, LLMModel, LLMResponse, LLMResponseChunk, ToolCall + + +class TestLangChainLLMAdapter: + def test_raw_llm_property(self): + mock_llm = MagicMock() + adapter = LangChainLLMAdapter(mock_llm) + assert adapter.raw_llm is mock_llm + + def test_model_name_property(self): + mock_llm = MagicMock() + mock_llm.model_name = "gpt-4" + adapter = LangChainLLMAdapter(mock_llm) + assert adapter.model_name == "gpt-4" + + def test_provider_name_property(self): + mock_llm = MagicMock() + mock_llm.__module__ = "langchain_openai.chat_models" + type(mock_llm).__module__ = "langchain_openai.chat_models" + adapter = LangChainLLMAdapter(mock_llm) + assert adapter.provider_name == "openai" + + def test_provider_url_from_base_url(self): + mock_llm = MagicMock(spec=[]) + mock_llm.base_url = "https://api.openai.com/v1" + adapter = LangChainLLMAdapter(mock_llm) + assert adapter.provider_url == "https://api.openai.com/v1" + + def test_provider_url_from_azure_endpoint(self): + mock_llm = MagicMock(spec=[]) + mock_llm.azure_endpoint = "https://example.openai.azure.com" + adapter = LangChainLLMAdapter(mock_llm) + assert adapter.provider_url == "https://example.openai.azure.com" + + def test_provider_url_from_server_url(self): + mock_llm = MagicMock(spec=[]) + mock_llm.server_url = "https://triton.example.com:8000" + adapter = LangChainLLMAdapter(mock_llm) + assert adapter.provider_url == "https://triton.example.com:8000" + + def test_provider_url_from_nested_client(self): + mock_llm = MagicMock(spec=[]) + mock_llm.client = MagicMock() + mock_llm.client.base_url = "https://custom.endpoint.com/v1" + adapter = LangChainLLMAdapter(mock_llm) + assert adapter.provider_url == "https://custom.endpoint.com/v1" + + def test_provider_url_returns_none_when_not_available(self): + mock_llm = MagicMock(spec=[]) + adapter = LangChainLLMAdapter(mock_llm) + assert adapter.provider_url is None + + @pytest.mark.asyncio + async def test_generate_with_string_prompt(self): + mock_llm = AsyncMock() + mock_llm.ainvoke.return_value = AIMessage(content="hello world") + adapter = LangChainLLMAdapter(mock_llm) + + result = await adapter.generate("say hello") + + assert isinstance(result, LLMResponse) + assert result.content == "hello world" + mock_llm.ainvoke.assert_called_once_with("say hello", stop=None) + + @pytest.mark.asyncio + async def test_generate_with_chat_message_list(self): + mock_llm = AsyncMock() + mock_llm.ainvoke.return_value = AIMessage(content="hi there") + adapter = LangChainLLMAdapter(mock_llm) + + result = await adapter.generate([ChatMessage.from_user("hello")]) + + assert isinstance(result, LLMResponse) + assert result.content == "hi there" + mock_llm.ainvoke.assert_called_once() + call_args = mock_llm.ainvoke.call_args + assert len(call_args[0][0]) == 1 + + @pytest.mark.asyncio + async def test_generate_returns_llm_response_with_reasoning(self): + mock_llm = AsyncMock() + mock_llm.ainvoke.return_value = AIMessage( + content="response", + additional_kwargs={"reasoning_content": "thinking"}, + response_metadata={"model": "gpt-4"}, + ) + adapter = LangChainLLMAdapter(mock_llm) + + result = await adapter.generate("prompt") + + assert isinstance(result, LLMResponse) + assert 
result.content == "response" + assert result.reasoning == "thinking" + assert result.model == "gpt-4" + + @pytest.mark.asyncio + async def test_generate_maps_tool_calls(self): + mock_llm = AsyncMock() + mock_llm.ainvoke.return_value = AIMessage( + content="", + tool_calls=[ + {"name": "search", "args": {"q": "weather"}, "id": "tc_1", "type": "tool_call"}, + ], + ) + adapter = LangChainLLMAdapter(mock_llm) + + result = await adapter.generate("prompt") + + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + assert isinstance(result.tool_calls[0], ToolCall) + assert result.tool_calls[0].function.name == "search" + assert result.tool_calls[0].function.arguments == {"q": "weather"} + assert result.tool_calls[0].id == "tc_1" + assert result.tool_calls[0].type == "function" + + @pytest.mark.asyncio + async def test_generate_maps_usage_metadata(self): + mock_llm = AsyncMock() + response = AIMessage(content="ok") + response.usage_metadata = {"total_tokens": 100, "input_tokens": 60, "output_tokens": 40} + mock_llm.ainvoke.return_value = response + adapter = LangChainLLMAdapter(mock_llm) + + result = await adapter.generate("prompt") + + assert result.usage is not None + assert result.usage.total_tokens == 100 + assert result.usage.input_tokens == 60 + assert result.usage.output_tokens == 40 + + @pytest.mark.asyncio + async def test_stream_yields_llm_response_chunks(self): + mock_llm = MagicMock() + chunks = [ + MagicMock(content="hel", response_metadata=None, usage_metadata=None, generation_info=None), + MagicMock(content="lo", response_metadata=None, usage_metadata=None, generation_info=None), + ] + + async def mock_astream(*args, **kwargs): + for c in chunks: + yield c + + mock_llm.astream = mock_astream + adapter = LangChainLLMAdapter(mock_llm) + + results = [] + async for chunk in adapter.stream("say hello"): + results.append(chunk) + + assert len(results) == 2 + assert all(isinstance(r, LLMResponseChunk) for r in results) + assert results[0].delta_content == "hel" + assert results[1].delta_content == "lo" + + @pytest.mark.asyncio + async def test_generate_passes_kwargs_via_bind(self): + mock_llm = MagicMock() + bound_llm = AsyncMock() + bound_llm.ainvoke.return_value = AIMessage(content="bound response") + mock_llm.bind.return_value = bound_llm + adapter = LangChainLLMAdapter(mock_llm) + + result = await adapter.generate("prompt", temperature=0.5, max_tokens=100) + + mock_llm.bind.assert_called_once_with(temperature=0.5, max_tokens=100) + bound_llm.ainvoke.assert_called_once_with("prompt", stop=None) + assert result.content == "bound response" + + @pytest.mark.asyncio + async def test_generate_no_kwargs_skips_bind(self): + mock_llm = AsyncMock() + mock_llm.ainvoke.return_value = AIMessage(content="direct response") + adapter = LangChainLLMAdapter(mock_llm) + + result = await adapter.generate("prompt") + + mock_llm.bind.assert_not_called() + assert result.content == "direct response" + + @pytest.mark.asyncio + async def test_generate_filters_reasoning_model_params(self): + mock_llm = MagicMock() + mock_llm.model_name = "o1-preview" + bound_llm = AsyncMock() + bound_llm.ainvoke.return_value = AIMessage(content="reasoning response") + mock_llm.bind.return_value = bound_llm + adapter = LangChainLLMAdapter(mock_llm) + + result = await adapter.generate("prompt", temperature=0.5, max_tokens=100) + + mock_llm.bind.assert_called_once_with(max_tokens=100) + assert result.content == "reasoning response" + + def test_satisfies_llm_model_protocol(self): + mock_llm = MagicMock() + 
mock_llm.model_name = "test" + adapter = LangChainLLMAdapter(mock_llm) + assert isinstance(adapter, LLMModel) + + +class TestLangChainFramework: + def test_framework_creates_model(self): + framework = LangChainFramework() + + mock_raw_llm = MagicMock() + with patch( + "nemoguardrails.llm.models.langchain_initializer.init_langchain_model", + return_value=mock_raw_llm, + ) as mock_init: + model = framework.create_model( + model_name="gpt-4", + provider_name="openai", + model_kwargs={"mode": "chat", "temperature": 0.5}, + ) + + assert isinstance(model, LangChainLLMAdapter) + assert model.raw_llm is mock_raw_llm + mock_init.assert_called_once_with( + model_name="gpt-4", + provider_name="openai", + mode="chat", + kwargs={"temperature": 0.5}, + ) + + +class TestFilterReasoningModelParams: + def _make_adapter(self, model_name): + mock_llm = MagicMock() + mock_llm.model_name = model_name + return LangChainLLMAdapter(mock_llm) + + @pytest.mark.parametrize( + "model,params,expected", + [ + ("gpt-4", {"temperature": 0.5, "max_tokens": 100}, {"temperature": 0.5, "max_tokens": 100}), + ("gpt-4o", {"temperature": 0.7}, {"temperature": 0.7}), + ("gpt-4o-mini", {"temperature": 0.3, "max_tokens": 50}, {"temperature": 0.3, "max_tokens": 50}), + ("gpt-5-chat", {"temperature": 0.5}, {"temperature": 0.5}), + ("o1-preview", {"temperature": 0.001, "max_tokens": 100}, {"max_tokens": 100}), + ("o1-mini", {"temperature": 0.5}, {}), + ("o3", {"temperature": 0.001, "max_tokens": 200}, {"max_tokens": 200}), + ("o3-mini", {"temperature": 0.1}, {}), + ("gpt-5", {"temperature": 0.001}, {}), + ("gpt-5-mini", {"temperature": 0.5, "max_tokens": 100}, {"max_tokens": 100}), + ("gpt-5-nano", {"temperature": 0.001}, {}), + ("o1-preview", {"max_tokens": 100}, {"max_tokens": 100}), + ("o1-preview", {}, {}), + ], + ) + def test_filter_params(self, model, params, expected): + adapter = self._make_adapter(model) + result = adapter._filter_reasoning_model_params(params) + assert result == expected + + def test_returns_none_when_params_is_none(self): + adapter = self._make_adapter("gpt-4") + result = adapter._filter_reasoning_model_params(None) + assert result is None + + def test_does_not_modify_original_params(self): + adapter = self._make_adapter("o1-preview") + params = {"temperature": 0.5, "max_tokens": 100} + adapter._filter_reasoning_model_params(params) + assert params == {"temperature": 0.5, "max_tokens": 100} + + +class TestConversionHelpers: + def test_langchain_response_to_llm_response_basic(self): + response = AIMessage(content="hello") + result = _langchain_response_to_llm_response(response) + + assert isinstance(result, LLMResponse) + assert result.content == "hello" + + def test_langchain_response_to_llm_response_with_tool_calls(self): + response = AIMessage( + content="", + tool_calls=[ + {"name": "fn", "args": {"x": 1}, "id": "tc1", "type": "tool_call"}, + ], + ) + result = _langchain_response_to_llm_response(response) + + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "fn" + assert result.tool_calls[0].function.arguments == {"x": 1} + + def test_langchain_response_fallback_to_str(self): + result = _langchain_response_to_llm_response("plain text") + + assert isinstance(result, LLMResponse) + assert result.content == "plain text" + + def test_langchain_chunk_to_llm_response_chunk_with_content(self): + chunk = MagicMock() + chunk.content = "hello" + chunk.response_metadata = {"finish_reason": "stop"} + chunk.usage_metadata = None + 
chunk.generation_info = None
+
+        result = _langchain_chunk_to_llm_response_chunk(chunk)
+
+        assert isinstance(result, LLMResponseChunk)
+        assert result.delta_content == "hello"
+        assert result.provider_metadata is not None
+        assert result.provider_metadata["finish_reason"] == "stop"
+
+    def test_langchain_chunk_to_llm_response_chunk_with_text(self):
+        chunk = MagicMock(spec=[])
+        chunk.text = "world"
+
+        result = _langchain_chunk_to_llm_response_chunk(chunk)
+
+        assert result.delta_content == "world"
+
+    def test_langchain_chunk_to_llm_response_chunk_fallback_str(self):
+        result = _langchain_chunk_to_llm_response_chunk("raw string")
+
+        assert isinstance(result, LLMResponseChunk)
+        assert result.delta_content == "raw string"

From 3982772aa8adc9910c168d77a518a1fa4773c543 Mon Sep 17 00:00:00 2001
From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
Date: Thu, 2 Apr 2026 15:38:17 +0200
Subject: [PATCH 2/2] fix(llm): address review feedback on adapter and
 framework registry

- Raise ValueError for unsupported ChatMessage roles
- Return "community" for 2-segment langchain_community paths
- Log debug when temperature is stripped for reasoning models
- Validate set_default_framework against known frameworks
- Fall back to input+output for missing total_tokens
- Use builtin dict in isinstance check
---
 .../integrations/langchain/llm_adapter.py   | 23 +++++++++++--------
 .../integrations/langchain/message_utils.py |  2 +-
 nemoguardrails/llm/frameworks.py            |  5 ++++
 tests/llm/test_frameworks.py                |  5 ++++
 4 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/nemoguardrails/integrations/langchain/llm_adapter.py b/nemoguardrails/integrations/langchain/llm_adapter.py
index a962cd7a08..856d3849bd 100644
--- a/nemoguardrails/integrations/langchain/llm_adapter.py
+++ b/nemoguardrails/integrations/langchain/llm_adapter.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import uuid
 from typing import Any, AsyncIterator, Dict, List, Optional, Union
 
@@ -27,6 +28,8 @@
     UsageInfo,
 )
 
+log = logging.getLogger(__name__)
+
 
 def _infer_model_name(llm: Any):
     """Helper to infer the model name from an LLM instance.
@@ -41,7 +44,7 @@ def _infer_model_name(llm: Any): return val model_kwargs = getattr(llm, "model_kwargs", None) - if model_kwargs and isinstance(model_kwargs, Dict): + if model_kwargs and isinstance(model_kwargs, dict): for attr in ["model", "model_name", "name"]: val = model_kwargs.get(attr) if isinstance(val, str): @@ -77,9 +80,7 @@ def _infer_provider_from_module(llm: Any) -> Optional[str]: if provider == "community": parts = module.split(".") - if len(parts) >= 3: - provider = parts[-1] - return provider + return parts[-1] if len(parts) >= 3 else "community" else: return provider @@ -91,9 +92,7 @@ def _infer_provider_from_module(llm: Any) -> Optional[str]: if provider == "community": parts = base_module.split(".") - if len(parts) >= 3: - provider = parts[-1] - return provider + return parts[-1] if len(parts) >= 3 else "community" else: return provider @@ -156,6 +155,7 @@ def _filter_reasoning_model_params(self, params: Optional[dict]) -> Optional[dic if is_openai_reasoning_model: filtered = params.copy() filtered.pop("temperature", None) + log.debug("Stripped 'temperature' for reasoning model '%s'", model_name) return filtered return params @@ -251,10 +251,13 @@ def _build_usage_info(raw: Any) -> Optional[UsageInfo]: return None if not raw: return None + input_tokens = raw.get("input_tokens", raw.get("prompt_tokens", 0)) + output_tokens = raw.get("output_tokens", raw.get("completion_tokens", 0)) + total_tokens = raw.get("total_tokens") or (input_tokens + output_tokens) return UsageInfo( - input_tokens=raw.get("input_tokens", raw.get("prompt_tokens", 0)), - output_tokens=raw.get("output_tokens", raw.get("completion_tokens", 0)), - total_tokens=raw.get("total_tokens", 0), + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, reasoning_tokens=raw.get("reasoning_tokens"), cached_tokens=raw.get("cached_tokens", raw.get("cache_read_input_tokens")), ) diff --git a/nemoguardrails/integrations/langchain/message_utils.py b/nemoguardrails/integrations/langchain/message_utils.py index cec37b0485..9cb5cbc7bc 100644 --- a/nemoguardrails/integrations/langchain/message_utils.py +++ b/nemoguardrails/integrations/langchain/message_utils.py @@ -170,7 +170,7 @@ def chatmessage_to_langchain_message(msg: "ChatMessage") -> BaseMessage: for tc in msg.tool_calls ] return AIMessage(content=content, **kwargs) - return HumanMessage(content=content) + raise ValueError(f"Unsupported ChatMessage role: {msg.role}") def chatmessages_to_langchain_messages(msgs: List["ChatMessage"]) -> List[BaseMessage]: diff --git a/nemoguardrails/llm/frameworks.py b/nemoguardrails/llm/frameworks.py index 143dd034c0..c7c62ce062 100644 --- a/nemoguardrails/llm/frameworks.py +++ b/nemoguardrails/llm/frameworks.py @@ -40,7 +40,12 @@ def get_framework(name: str) -> LLMFramework: return _frameworks[name] +_LAZY_FRAMEWORKS = {"langchain"} + + def set_default_framework(name: str) -> None: + if name not in _frameworks and name not in _LAZY_FRAMEWORKS: + raise KeyError(f"Unknown framework '{name}'. 
Register it first or use one of: {sorted(_LAZY_FRAMEWORKS)}") global _default_framework _default_framework = name diff --git a/tests/llm/test_frameworks.py b/tests/llm/test_frameworks.py index fe33db91a9..c76416ccd6 100644 --- a/tests/llm/test_frameworks.py +++ b/tests/llm/test_frameworks.py @@ -60,9 +60,14 @@ def test_langchain_lazy_auto_registration(self): assert isinstance(fw, LangChainFramework) def test_set_and_get_default_framework(self): + register_framework("custom", FakeFramework()) set_default_framework("custom") assert get_default_framework() == "custom" + def test_set_default_unknown_raises(self): + with pytest.raises(KeyError, match="Unknown framework"): + set_default_framework("nonexistent") + def test_default_is_langchain(self): assert get_default_framework() == "langchain"
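
Usage sketch of the API this series adds, combining the framework registry and
the adapter. This is a minimal illustration, not part of the patch: the model
and provider names are placeholders, and the call assumes the matching
LangChain provider package is installed so init_langchain_model can resolve it.

    import asyncio

    from nemoguardrails.llm.frameworks import get_default_framework, get_framework


    async def main():
        # "langchain" is the default and is lazily auto-registered on first lookup.
        framework = get_framework(get_default_framework())

        # create_model() returns a LangChainLLMAdapter conforming to the LLMModel
        # protocol. "mode" is popped from model_kwargs and selects chat vs. text;
        # the remaining kwargs are forwarded to the LangChain initializer.
        model = framework.create_model(
            model_name="gpt-4",        # placeholder
            provider_name="openai",    # placeholder
            model_kwargs={"mode": "chat", "temperature": 0.2},
        )

        # Framework-agnostic call surface: a normalized LLMResponse comes back,
        # with content, model, finish_reason, usage, etc. mapped by the adapter.
        response = await model.generate("Say hello.")
        print(response.content, response.model, response.usage)


    asyncio.run(main())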