diff --git a/asyncua/common/structures104.py b/asyncua/common/structures104.py index 78c7d6322..0138e5b3e 100644 --- a/asyncua/common/structures104.py +++ b/asyncua/common/structures104.py @@ -297,7 +297,7 @@ def make_structure( else: default_val = None - fields.append((fname, uatype, default_val)) + fields.append((fname, prop_uatype, default_val)) namespace = { "ua": ua, @@ -437,6 +437,20 @@ async def get_children_descriptions_type_definitions( idxs = [] for idx, desc in enumerate(descs): if hasattr(ua, desc.BrowseName.Name) and not overwrite_existing: + existing = getattr(ua, desc.BrowseName.Name) + existing_dtype = getattr(existing, "data_type", None) + if isinstance(existing_dtype, ua.NodeId) and existing_dtype != desc.NodeId: + _logger.warning( + "DataType name collision for %s: existing=%s, discovered=%s. Skipping discovered type because overwrite_existing is False.", + desc.BrowseName.Name, + existing_dtype, + desc.NodeId, + ) + _logger.warning( + "Skipping DataType %s (%s) because class already exists and overwrite_existing is False.", + desc.BrowseName.Name, + desc.NodeId, + ) continue idxs.append(idx) nodes.append(server.get_node(desc.NodeId)) diff --git a/asyncua/ua/ua_binary.py b/asyncua/ua/ua_binary.py index aa8be8ee7..8bdc05931 100644 --- a/asyncua/ua/ua_binary.py +++ b/asyncua/ua/ua_binary.py @@ -8,9 +8,10 @@ import functools import logging import struct +import sys import typing import uuid -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import Field, fields, is_dataclass from enum import Enum, IntFlag from io import BytesIO @@ -46,16 +47,20 @@ def set_string_encoding(new_encoding: str) -> None: _string_encoding.set(new_encoding) -def get_safe_type_hints(cls: type, extra_ns: dict[str, Any] | None = None) -> dict[str, Any]: - # Use globalns=None so that get_type_hints automatically resolves the - # module globals of cls (e.g. bare names like Byte). - # Pass extra_ns (e.g. {'ua': ua}) as localns so ua.Xxx annotations resolve too. - # Filter out properties from the class dict to avoid shadowing. +def get_safe_type_hints(cls: type[Any], extra_ns: Mapping[str, Any] | None = None) -> dict[str, Any]: + # Resolve annotations with explicit module globals for stable behavior + # across Python versions (notably 3.10 forward-reference handling). + module = sys.modules.get(cls.__module__) + globalns = vars(module).copy() if module is not None else {} + if extra_ns: + globalns.update(extra_ns) + + # Keep class-local names available and avoid property shadowing. localns = {k: v for k, v in cls.__dict__.items() if not isinstance(v, property)} if extra_ns: localns.update(extra_ns) - return typing.get_type_hints(cls, globalns=None, localns=localns) + return typing.get_type_hints(cls, globalns=globalns, localns=localns) def test_bit(data: int, offset: int) -> int: @@ -327,9 +332,35 @@ def resolve_uatype(ftype: Any) -> tuple[Any, bool]: return ftype, is_optional +def _resolve_type_in_dataclass_context(ftype: Any, dataclazz: type) -> Any: + if not isinstance(ftype, str): + return ftype + + module = sys.modules.get(getattr(dataclazz, "__module__", "")) if isinstance(dataclazz, type) else None + namespace = { + "ua": ua, + "typing": typing, + "list": list, + "List": list, + "Union": typing.Union, + "Optional": typing.Optional, + "Dict": dict, + } + if module is not None: + namespace.update(vars(module)) + if isinstance(dataclazz, type): + namespace.update({k: v for k, v in dataclazz.__dict__.items() if not isinstance(v, property)}) + + try: + return eval(ftype, namespace) + except Exception: + return ftype + + def field_serializer(uatype: Any, is_optional: bool, dataclazz: type) -> Callable[[Any], bytes]: if type_is_list(uatype): ft = type_from_list(uatype) + ft = _resolve_type_in_dataclass_context(ft, dataclazz) if is_optional: return lambda val: b"" if val is None else create_list_serializer(ft, ft == dataclazz)(val) return create_list_serializer(ft, ft == dataclazz) @@ -460,21 +491,15 @@ def create_list_serializer(uatype: type, recursive: bool = False) -> Callable[[S data_type = getattr(Primitives1, uatype.__name__) return data_type.pack_array none_val = Primitives.Int32.pack(-1) - if recursive: - - def recursive_serialize(val: Sequence[Any] | None) -> bytes: - if val is None: - return none_val - data_size = Primitives.Int32.pack(len(val)) - return data_size + b"".join(create_type_serializer(uatype)(el) for el in val) - - return recursive_serialize - type_serializer = create_type_serializer(uatype) + type_serializer = None def serialize(val: Sequence[Any] | None) -> bytes: + nonlocal type_serializer if val is None: return none_val + if type_serializer is None: + type_serializer = create_type_serializer(uatype) data_size = Primitives.Int32.pack(len(val)) return data_size + b"".join(type_serializer(el) for el in val) @@ -624,7 +649,10 @@ def extensionobject_from_binary(data: Buffer) -> Any: cls = ua.extension_objects_by_typeid[typeid] if body is None: raise UaError(f"parsing ExtensionObject {cls.__name__} without data") - return from_binary(cls, body) + try: + return from_binary(cls, body) + except Exception as exc: + raise UaError(f"Error decoding ExtensionObject {cls.__name__} with TypeId {typeid}") from exc if body is not None: body_data = body.read(len(body)) else: @@ -662,25 +690,23 @@ def extensionobject_to_binary(obj: Any) -> bytes: @functools.cache -def _create_list_deserializer(uatype: type, recursive: bool = False) -> Callable[[Buffer | IO], list[Any]]: - if recursive: +def _create_list_deserializer(uatype: Any, recursive: bool = False) -> Callable[[Buffer | IO[Any]], list[Any]]: + # Resolve the element decoder lazily so mutually-recursive dataclass lists + # do not recurse forever during deserializer construction. + element_deserializer = None - def _deserialize_recursive(data: Buffer | IO) -> list[Any]: - size = Primitives.Int32.unpack(data) - return [_create_type_deserializer(uatype, type(None))(data) for _ in range(size)] - - return _deserialize_recursive - element_deserializer = _create_type_deserializer(uatype, type(None)) - - def _deserialize(data: Buffer | IO) -> list[Any]: + def _deserialize(data: Buffer | IO[Any]) -> list[Any]: + nonlocal element_deserializer size = Primitives.Int32.unpack(data) + if element_deserializer is None: + element_deserializer = _create_type_deserializer(uatype, type(None)) return [element_deserializer(data) for _ in range(size)] return _deserialize @functools.cache -def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer | IO], Any]: +def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer | IO[Any]], Any]: uatype, is_optional = resolve_uatype(uatype) if not is_optional and type_is_union(uatype): @@ -689,6 +715,7 @@ def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer return _create_type_deserializer(uatype, uatype) if type_is_list(uatype): utype = type_from_list(uatype) + utype = _resolve_type_in_dataclass_context(utype, dataclazz) if hasattr(ua.VariantType, utype.__name__): vtype = getattr(ua.VariantType, utype.__name__) return _create_uatype_array_deserializer(vtype) @@ -767,10 +794,17 @@ def decode(data: Buffer | IO) -> Any: kwargs: dict[str, Any] = {} enc: int = 0 for field, optional_enc_bit, deserialize_field in dc_field_deserializers: - if field.name == "Encoding": - enc = deserialize_field(data) - elif optional_enc_bit == 0 or enc & optional_enc_bit: - kwargs[field.name] = deserialize_field(data) + try: + if field.name == "Encoding": + enc = deserialize_field(data) + elif optional_enc_bit == 0 or enc & optional_enc_bit: + kwargs[field.name] = deserialize_field(data) + except Exception as exc: + remaining = len(data) if hasattr(data, "__len__") else "unknown" + raise UaError( + f"Error decoding field {objtype.__name__}.{field.name} " + f"(encoding={enc}, optional_bit={optional_enc_bit}, remaining={remaining})" + ) from exc return objtype(**kwargs) return decode diff --git a/tests/test_unit.py b/tests/test_unit.py index 71d941f7b..f0ea6b991 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -17,6 +17,7 @@ from asyncua.common.connection import MessageChunk from asyncua.common.event_objects import BaseEvent from asyncua.common.structures import StructGenerator +from asyncua.common.structures104 import make_structure from asyncua.common.ua_utils import string_to_val, val_to_string from asyncua.crypto.security_policies import SecurityPolicyNone from asyncua.server.monitored_item_service import WhereClauseEvaluator @@ -37,6 +38,18 @@ EXAMPLE_BSD_PATH = Path(__file__).parent.absolute() / "example.bsd" +@dataclass +class _MutualRecursiveChild: + Name: ua.String = "" + Parents: list["_MutualRecursiveParent"] = field(default_factory=list) + + +@dataclass +class _MutualRecursiveParent: + Name: ua.String = "" + Children: list[_MutualRecursiveChild] = field(default_factory=list) + + def test_variant_array_none(): v = ua.Variant(None, VariantType=ua.VariantType.Int32, is_array=True) data = variant_to_binary(v) @@ -910,6 +923,44 @@ class MyStruct: assert m == m2 +def test_struct104_optional_field_respects_encoding_mask() -> None: + sdef = ua.StructureDefinition() + sdef.StructureType = ua.StructureType.StructureWithOptionalFields + + desc = ua.StructureField() + desc.Name = "Description" + desc.DataType = ua.NodeId(ua.ObjectIds.LocalizedText) + desc.IsOptional = True + desc.ValueRank = ua.ValueRank.OneDimension + desc.ArrayDimensions = [0] + sdef.Fields = [desc] + + cls = make_structure(ua.NodeId(65001, 2), "_OptionalFieldMaskStruct", sdef)["_OptionalFieldMaskStruct"] + + # Encoding mask without optional bits set: optional field payload must be absent. + data = ua.ua_binary.Primitives.UInt32.pack(0) + decoded = struct_from_binary(cls, ua.utils.Buffer(data)) + + assert decoded.Description is None + + +def test_struct_mutual_recursive_lists_roundtrip() -> None: + root = _MutualRecursiveParent(Name="root") + child = _MutualRecursiveChild(Name="leaf") + branch = _MutualRecursiveParent(Name="branch") + root.Children.append(child) + child.Parents.append(branch) + + data = struct_to_binary(root) + decoded = struct_from_binary(_MutualRecursiveParent, ua.utils.Buffer(data)) + + assert decoded.Name == "root" + assert len(decoded.Children) == 1 + assert decoded.Children[0].Name == "leaf" + assert len(decoded.Children[0].Parents) == 1 + assert decoded.Children[0].Parents[0].Name == "branch" + + def test_session_security_diagnostics_roundtrip(): """Regression test: SessionSecurityDiagnosticsDataType has a bare 'Encoding: Byte' annotation (not quoted as 'ua.Byte'). With