Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/docstub-stubs/_stubs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
Expand Down Expand Up @@ -34,6 +35,9 @@ class ScopeType(enum.StrEnum):
CLASSMETHOD = "classmethod"
STATICMETHOD = "staticmethod"

_dataclass_name: cstm.Name
_dataclass_matcher: cstm.ClassDef

@dataclass(slots=True, frozen=True)
class _Scope:

Expand All @@ -51,7 +55,7 @@ class _Scope:

def _get_docstring_node(
node: cst.FunctionDef | cst.ClassDef | cst.Module,
) -> cst.SimpleString | cst.ConcatenatedString | None: ...
) -> tuple[cst.SimpleString | cst.ConcatenatedString | None, str | None]: ...
def _log_error_with_line_context(cls: Py2StubTransformer) -> Py2StubTransformer: ...
def _docstub_comment_directives(cls: Py2StubTransformer) -> Py2StubTransformer: ...
def _inline_node_as_code(node: cst.CSTNode) -> str: ...
Expand Down
73 changes: 48 additions & 25 deletions src/docstub/_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import enum
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from functools import wraps
from typing import ClassVar
Expand Down Expand Up @@ -54,6 +55,22 @@ class ScopeType(enum.StrEnum):
# docstub: on


# To be used with `libcst.matchers.matches()` to guess if a node is a dataclass
# See `test_dataclass_matcher` for supported cases
_dataclass_name: cstm.Name = cstm.Name("dataclass")
_dataclass_matcher: cstm.ClassDef = cstm.ClassDef(
decorators=[
cstm.Decorator(
decorator=(
_dataclass_name
| cstm.Call(func=_dataclass_name | cstm.Attribute(attr=_dataclass_name))
| cstm.Attribute(attr=_dataclass_name)
)
),
]
)


# TODO use `libcst.metadata.ScopeProvider` instead
@dataclass(slots=True, frozen=True)
class _Scope:
Expand Down Expand Up @@ -81,14 +98,8 @@ def is_class_init(self) -> bool:

@property
def is_dataclass(self) -> bool:
if cstm.matches(self.node, cstm.ClassDef()):
# Determine if dataclass
decorators = cstm.findall(self.node, cstm.Decorator())
is_dataclass = any(
cstm.findall(d, cstm.Name("dataclass")) for d in decorators
)
return is_dataclass
return False
is_dataclass = cstm.matches(self.node, _dataclass_matcher)
return is_dataclass


def _get_docstring_node(node):
Expand All @@ -106,23 +117,35 @@ def _get_docstring_node(node):
-------
docstring_node : cst.SimpleString | cst.ConcatenatedString | None
The node of the docstring if found.
docstring_value : str | None
The value of the docstring if found.
"""
docstring_node = None

docstring = node.get_docstring(clean=False)
if docstring:
# Workaround to find the exact postion of a docstring
# by using its node
string_nodes = cstm.findall(
node, cstm.SimpleString() | cstm.ConcatenatedString()
)
matching_nodes = [
node for node in string_nodes if node.evaluated_value == docstring
]
assert len(matching_nodes) == 1
docstring_node = matching_nodes[0]

return docstring_node
# Copied from https://github.com/Instagram/LibCST/blob/9275a8bf7875d08659ce7b266860138bba633410/libcst/_nodes/statement.py#L1669
body = node.body
if isinstance(body, Sequence):
if body:
expr = body[0]
else:
return (None, None)
else:
expr = body
while isinstance(expr, (cst.BaseSuite, cst.SimpleStatementLine)):
if len(expr.body) == 0:
return (None, None)
expr = expr.body[0]
if not isinstance(expr, cst.Expr):
return (None, None)

docstring_node = expr.value
if isinstance(docstring_node, (cst.SimpleString, cst.ConcatenatedString)):
docstring_value = docstring_node.evaluated_value
else:
return (None, None)
if isinstance(docstring_value, bytes):
return (None, None)

return docstring_node, docstring_value


def _log_error_with_line_context(cls):
Expand Down Expand Up @@ -897,7 +920,7 @@ def _annotations_from_node(self, node):
"""
annotations = None

docstring_node = _get_docstring_node(node)
docstring_node, docstring_value = _get_docstring_node(node)
if docstring_node:
position = self.get_metadata(
cst.metadata.PositionProvider, docstring_node
Expand All @@ -907,7 +930,7 @@ def _annotations_from_node(self, node):
)
try:
annotations = DocstringAnnotations(
docstring_node.evaluated_value,
docstring_value,
transformer=self.transformer,
reporter=reporter,
)
Expand Down
26 changes: 25 additions & 1 deletion tests/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import libcst.matchers as cstm
import pytest

from docstub._stubs import Py2StubTransformer, _get_docstring_node
from docstub._stubs import Py2StubTransformer, _dataclass_matcher, _get_docstring_node


class Test_get_docstring_node:
Expand Down Expand Up @@ -761,3 +761,27 @@ def foo(*args: str, **kwargs: int) -> None: ...
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result


@pytest.mark.parametrize(
("decorators", "expected"),
[
("@dataclass", True),
("@dataclass(frozen=True)", True),
("@dataclasses.dataclass(frozen=True)", True),
("@dc.dataclass", True),
("", False),
("@other", False),
("@other(dataclass=True)", False),
],
)
def test_dataclass_matcher(decorators, expected):
source = dedent(
"""
{decorators}
class Foo:
pass
"""
).format(decorators=decorators)
class_def = cst.parse_statement(source)
assert cstm.matches(class_def, _dataclass_matcher) is expected
Loading