Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions nemoguardrails/guardrails/rail_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for IORails rail actions.

Defines the template-method pipeline: validate → extract → prompt → respond → parse.
Subclasses override individual steps. The base provides three concrete response
helpers for the common call patterns (LLM, API, local).
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

from nemoguardrails.guardrails.guardrails_types import (
LLMMessages,
RailResult,
get_request_id,
truncate,
)
from nemoguardrails.guardrails.model_manager import ModelManager
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.rails.llm.config import _get_flow_model, _get_flow_name

log = logging.getLogger(__name__)


class RailAction(ABC):
"""Base class for all IORails rail actions.

Subclasses implement the abstract ``_``-prefixed hooks to customise each
stage of the pipeline. The public entry point is :meth:`run`.

Subclasses must define the class attribute:
- action_name: The base flow name as it appears in RailsConfig
(e.g. ``"content safety check input"``).
"""

action_name: str

def __init__(self, model_manager: ModelManager, task_manager: LLMTaskManager) -> None:
self.model_manager = model_manager
self.task_manager = task_manager

# ------------------------------------------------------------------
# Public entry point (template method)
# ------------------------------------------------------------------

async def run(
self,
flow: str,
messages: LLMMessages,
bot_response: Optional[str] = None,
) -> RailResult:
"""Execute the full rail pipeline and return a safety result."""
req_id = get_request_id()
base_flow = _get_flow_name(flow)
model_type = _get_flow_model(flow)

self._validate_input(flow, messages, bot_response)

extracted = self._extract_messages(messages, bot_response)
log.debug("[%s] %s extracted: %s", req_id, base_flow, truncate(extracted))

prompt = self._create_prompt(flow, extracted)
if prompt is not None:
log.debug("[%s] %s prompt: %s", req_id, base_flow, truncate(prompt))

try:
response = await self._get_response(flow, prompt, model_type)
log.debug("[%s] %s response: %s", req_id, base_flow, truncate(response))
return self._parse_response(response)
except Exception as e:
log.error("[%s] %s failed: %s", req_id, base_flow, e)
return RailResult(is_safe=False, reason=f"{base_flow} error: {e}")

# ------------------------------------------------------------------
# Abstract hooks — subclasses must implement these
# ------------------------------------------------------------------

@abstractmethod
def _validate_input(
self,
flow: str,
messages: LLMMessages,
bot_response: Optional[str],
) -> None:
"""Raise if inputs are invalid (e.g. missing $model= parameter)."""

@abstractmethod
def _extract_messages(
self,
messages: LLMMessages,
bot_response: Optional[str],
) -> dict[str, Any]:
"""Extract the relevant fields from messages into a dict.

Returns a dict of extracted values that will be passed to _create_prompt.
"""

@abstractmethod
def _create_prompt(
self,
flow: str,
extracted: dict[str, Any],
) -> Any:
"""Build the prompt / request payload from extracted data.

Returns whatever _get_response needs: a message list, a dict body, etc.
May return None if the response step doesn't need a prompt (e.g. API calls
that build their own payload).
"""

@abstractmethod
async def _get_response(
self,
flow: str,
prompt: Any,
model_type: Optional[str],
) -> Any:
"""Call the model/API/local engine and return the raw response."""

@abstractmethod
def _parse_response(self, response: Any) -> RailResult:
"""Convert the raw response into a RailResult."""

# ------------------------------------------------------------------
# Concrete response helpers — subclasses call these from _get_response
# ------------------------------------------------------------------

async def _get_llm_response(
self,
model_type: str,
messages: list[dict],
**kwargs: Any,
) -> str:
"""Call an LLM via ModelManager and return the response text."""
return await self.model_manager.generate_async(model_type, messages, **kwargs)

async def _get_api_response(
self,
api_name: str,
body: dict[str, Any],
**kwargs: Any,
) -> dict[str, Any]:
"""Call an API endpoint via ModelManager and return the response dict."""
return await self.model_manager.api_call(api_name, body, **kwargs)

async def _get_local_response(self, **kwargs: Any) -> Any:
"""Run a local/in-process check. Override in subclasses that need it."""
raise NotImplementedError("Subclass must override _get_local_response")

# ------------------------------------------------------------------
# Shared utilities
# ------------------------------------------------------------------

def _validate_flow_name(self, flow: str) -> None:
"""Verify the flow's base name matches this action's action_name."""
base_flow = _get_flow_name(flow)
if base_flow != self.action_name:
raise RuntimeError(f"Flow '{base_flow}' does not match expected action_name '{self.action_name}'")

@staticmethod
def _last_user_content(messages: LLMMessages) -> str:
"""Return the content of the last user message."""
for msg in reversed(messages):
if msg.get("role") == "user" and msg.get("content"):
return msg["content"]
raise RuntimeError(f"No user message found in: {messages}")

@staticmethod
def _require_model_type(flow: str) -> str:
"""Extract the $model=<type> from a flow string, or raise."""
model_type = _get_flow_model(flow)
if not model_type:
raise RuntimeError(f"No $model= specified in flow: {flow}")
return model_type

@staticmethod
def _prompt_to_messages(prompt: Union[str, list[dict]]) -> list[dict]:
"""Convert LLMTaskManager render output to role/content message format."""
if isinstance(prompt, str):
return [{"role": "user", "content": prompt}]
return [{"role": m["type"], "content": m["content"]} for m in prompt]
Loading
Loading