Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion asyncua/common/structures104.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def make_structure(
else:
default_val = None

fields.append((fname, uatype, default_val))
fields.append((fname, prop_uatype, default_val))

namespace = {
"ua": ua,
Expand Down Expand Up @@ -437,6 +437,20 @@ async def get_children_descriptions_type_definitions(
idxs = []
for idx, desc in enumerate(descs):
if hasattr(ua, desc.BrowseName.Name) and not overwrite_existing:
existing = getattr(ua, desc.BrowseName.Name)
existing_dtype = getattr(existing, "data_type", None)
if isinstance(existing_dtype, ua.NodeId) and existing_dtype != desc.NodeId:
_logger.warning(
"DataType name collision for %s: existing=%s, discovered=%s. Skipping discovered type because overwrite_existing is False.",
desc.BrowseName.Name,
existing_dtype,
desc.NodeId,
)
_logger.warning(
"Skipping DataType %s (%s) because class already exists and overwrite_existing is False.",
desc.BrowseName.Name,
desc.NodeId,
)
continue
idxs.append(idx)
nodes.append(server.get_node(desc.NodeId))
Expand Down
100 changes: 67 additions & 33 deletions asyncua/ua/ua_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import functools
import logging
import struct
import sys
import typing
import uuid
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from dataclasses import Field, fields, is_dataclass
from enum import Enum, IntFlag
from io import BytesIO
Expand Down Expand Up @@ -46,16 +47,20 @@ def set_string_encoding(new_encoding: str) -> None:
_string_encoding.set(new_encoding)


def get_safe_type_hints(cls: type, extra_ns: dict[str, Any] | None = None) -> dict[str, Any]:
# Use globalns=None so that get_type_hints automatically resolves the
# module globals of cls (e.g. bare names like Byte).
# Pass extra_ns (e.g. {'ua': ua}) as localns so ua.Xxx annotations resolve too.
# Filter out properties from the class dict to avoid shadowing.
def get_safe_type_hints(cls: type[Any], extra_ns: Mapping[str, Any] | None = None) -> dict[str, Any]:
# Resolve annotations with explicit module globals for stable behavior
# across Python versions (notably 3.10 forward-reference handling).
module = sys.modules.get(cls.__module__)
globalns = vars(module).copy() if module is not None else {}
if extra_ns:
globalns.update(extra_ns)

# Keep class-local names available and avoid property shadowing.
localns = {k: v for k, v in cls.__dict__.items() if not isinstance(v, property)}
if extra_ns:
localns.update(extra_ns)

return typing.get_type_hints(cls, globalns=None, localns=localns)
return typing.get_type_hints(cls, globalns=globalns, localns=localns)


def test_bit(data: int, offset: int) -> int:
Expand Down Expand Up @@ -327,9 +332,35 @@ def resolve_uatype(ftype: Any) -> tuple[Any, bool]:
return ftype, is_optional


def _resolve_type_in_dataclass_context(ftype: Any, dataclazz: type) -> Any:
if not isinstance(ftype, str):
return ftype

module = sys.modules.get(getattr(dataclazz, "__module__", "")) if isinstance(dataclazz, type) else None
namespace = {
"ua": ua,
"typing": typing,
"list": list,
"List": list,
"Union": typing.Union,
"Optional": typing.Optional,
"Dict": dict,
}
if module is not None:
namespace.update(vars(module))
if isinstance(dataclazz, type):
namespace.update({k: v for k, v in dataclazz.__dict__.items() if not isinstance(v, property)})

try:
return eval(ftype, namespace)
except Exception:
return ftype


def field_serializer(uatype: Any, is_optional: bool, dataclazz: type) -> Callable[[Any], bytes]:
if type_is_list(uatype):
ft = type_from_list(uatype)
ft = _resolve_type_in_dataclass_context(ft, dataclazz)
if is_optional:
return lambda val: b"" if val is None else create_list_serializer(ft, ft == dataclazz)(val)
return create_list_serializer(ft, ft == dataclazz)
Expand Down Expand Up @@ -460,21 +491,15 @@ def create_list_serializer(uatype: type, recursive: bool = False) -> Callable[[S
data_type = getattr(Primitives1, uatype.__name__)
return data_type.pack_array
none_val = Primitives.Int32.pack(-1)
if recursive:

def recursive_serialize(val: Sequence[Any] | None) -> bytes:
if val is None:
return none_val
data_size = Primitives.Int32.pack(len(val))
return data_size + b"".join(create_type_serializer(uatype)(el) for el in val)

return recursive_serialize

type_serializer = create_type_serializer(uatype)
type_serializer = None

def serialize(val: Sequence[Any] | None) -> bytes:
nonlocal type_serializer
if val is None:
return none_val
if type_serializer is None:
type_serializer = create_type_serializer(uatype)
data_size = Primitives.Int32.pack(len(val))
return data_size + b"".join(type_serializer(el) for el in val)

Expand Down Expand Up @@ -624,7 +649,10 @@ def extensionobject_from_binary(data: Buffer) -> Any:
cls = ua.extension_objects_by_typeid[typeid]
if body is None:
raise UaError(f"parsing ExtensionObject {cls.__name__} without data")
return from_binary(cls, body)
try:
return from_binary(cls, body)
except Exception as exc:
raise UaError(f"Error decoding ExtensionObject {cls.__name__} with TypeId {typeid}") from exc
if body is not None:
body_data = body.read(len(body))
else:
Expand Down Expand Up @@ -662,25 +690,23 @@ def extensionobject_to_binary(obj: Any) -> bytes:


@functools.cache
def _create_list_deserializer(uatype: type, recursive: bool = False) -> Callable[[Buffer | IO], list[Any]]:
if recursive:
def _create_list_deserializer(uatype: Any, recursive: bool = False) -> Callable[[Buffer | IO[Any]], list[Any]]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not follow everything here but that typing change looks very strange. The entire point of that method is to take binary data and transform it into ojjects... hwy would it take Any??

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

somehow typechecking blamed it... i looked it up:

Why not just IO?
In Python typing, IO is generic. Writing bare IO is considered incomplete by type checkers. IO[Any] means “an IO object of unspecified content type,” which is practical here because many real stream objects are not strictly annotated as IO[bytes] even though they return bytes at runtime.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry somehow I missed the IO[] part ...

# Resolve the element decoder lazily so mutually-recursive dataclass lists
# do not recurse forever during deserializer construction.
element_deserializer = None

def _deserialize_recursive(data: Buffer | IO) -> list[Any]:
size = Primitives.Int32.unpack(data)
return [_create_type_deserializer(uatype, type(None))(data) for _ in range(size)]

return _deserialize_recursive
element_deserializer = _create_type_deserializer(uatype, type(None))

def _deserialize(data: Buffer | IO) -> list[Any]:
def _deserialize(data: Buffer | IO[Any]) -> list[Any]:
nonlocal element_deserializer
size = Primitives.Int32.unpack(data)
if element_deserializer is None:
element_deserializer = _create_type_deserializer(uatype, type(None))
return [element_deserializer(data) for _ in range(size)]

return _deserialize


@functools.cache
def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer | IO], Any]:
def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer | IO[Any]], Any]:
uatype, is_optional = resolve_uatype(uatype)

if not is_optional and type_is_union(uatype):
Expand All @@ -689,6 +715,7 @@ def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer
return _create_type_deserializer(uatype, uatype)
if type_is_list(uatype):
utype = type_from_list(uatype)
utype = _resolve_type_in_dataclass_context(utype, dataclazz)
if hasattr(ua.VariantType, utype.__name__):
vtype = getattr(ua.VariantType, utype.__name__)
return _create_uatype_array_deserializer(vtype)
Expand Down Expand Up @@ -767,10 +794,17 @@ def decode(data: Buffer | IO) -> Any:
kwargs: dict[str, Any] = {}
enc: int = 0
for field, optional_enc_bit, deserialize_field in dc_field_deserializers:
if field.name == "Encoding":
enc = deserialize_field(data)
elif optional_enc_bit == 0 or enc & optional_enc_bit:
kwargs[field.name] = deserialize_field(data)
try:
if field.name == "Encoding":
enc = deserialize_field(data)
elif optional_enc_bit == 0 or enc & optional_enc_bit:
kwargs[field.name] = deserialize_field(data)
except Exception as exc:
remaining = len(data) if hasattr(data, "__len__") else "unknown"
raise UaError(
f"Error decoding field {objtype.__name__}.{field.name} "
f"(encoding={enc}, optional_bit={optional_enc_bit}, remaining={remaining})"
) from exc
return objtype(**kwargs)

return decode
Expand Down
51 changes: 51 additions & 0 deletions tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from asyncua.common.connection import MessageChunk
from asyncua.common.event_objects import BaseEvent
from asyncua.common.structures import StructGenerator
from asyncua.common.structures104 import make_structure
from asyncua.common.ua_utils import string_to_val, val_to_string
from asyncua.crypto.security_policies import SecurityPolicyNone
from asyncua.server.monitored_item_service import WhereClauseEvaluator
Expand All @@ -37,6 +38,18 @@
EXAMPLE_BSD_PATH = Path(__file__).parent.absolute() / "example.bsd"


@dataclass
class _MutualRecursiveChild:
Name: ua.String = ""
Parents: list["_MutualRecursiveParent"] = field(default_factory=list)


@dataclass
class _MutualRecursiveParent:
Name: ua.String = ""
Children: list[_MutualRecursiveChild] = field(default_factory=list)


def test_variant_array_none():
v = ua.Variant(None, VariantType=ua.VariantType.Int32, is_array=True)
data = variant_to_binary(v)
Expand Down Expand Up @@ -910,6 +923,44 @@ class MyStruct:
assert m == m2


def test_struct104_optional_field_respects_encoding_mask() -> None:
sdef = ua.StructureDefinition()
sdef.StructureType = ua.StructureType.StructureWithOptionalFields

desc = ua.StructureField()
desc.Name = "Description"
desc.DataType = ua.NodeId(ua.ObjectIds.LocalizedText)
desc.IsOptional = True
desc.ValueRank = ua.ValueRank.OneDimension
desc.ArrayDimensions = [0]
sdef.Fields = [desc]

cls = make_structure(ua.NodeId(65001, 2), "_OptionalFieldMaskStruct", sdef)["_OptionalFieldMaskStruct"]

# Encoding mask without optional bits set: optional field payload must be absent.
data = ua.ua_binary.Primitives.UInt32.pack(0)
decoded = struct_from_binary(cls, ua.utils.Buffer(data))

assert decoded.Description is None


def test_struct_mutual_recursive_lists_roundtrip() -> None:
root = _MutualRecursiveParent(Name="root")
child = _MutualRecursiveChild(Name="leaf")
branch = _MutualRecursiveParent(Name="branch")
root.Children.append(child)
child.Parents.append(branch)

data = struct_to_binary(root)
decoded = struct_from_binary(_MutualRecursiveParent, ua.utils.Buffer(data))

assert decoded.Name == "root"
assert len(decoded.Children) == 1
assert decoded.Children[0].Name == "leaf"
assert len(decoded.Children[0].Parents) == 1
assert decoded.Children[0].Parents[0].Name == "branch"


def test_session_security_diagnostics_roundtrip():
"""Regression test: SessionSecurityDiagnosticsDataType has a bare
'Encoding: Byte' annotation (not quoted as 'ua.Byte'). With
Expand Down
Loading