diff --git a/piptools/writer.py b/piptools/writer.py index f2ebece06..621e49b29 100644 --- a/piptools/writer.py +++ b/piptools/writer.py @@ -1,11 +1,14 @@ +import contextlib import io import os import re import sys +from dataclasses import dataclass from itertools import chain from typing import ( BinaryIO, Dict, + Generator, Iterable, Iterator, List, @@ -82,6 +85,39 @@ def annotation_style_line(required_by: Set[str]) -> str: return f"# via {', '.join(sorted(required_by))}" +@dataclass +class _LineWriter: + _io: io.TextIOWrapper + + def write(self, line: str) -> None: + log.info(line) + self._io.write(unstyle(line)) + self._io.write("\n") + + @classmethod + @contextlib.contextmanager + def create( + cls, buffer: BinaryIO, newline: str + ) -> Generator["_LineWriter", object, None]: + wrapper = io.TextIOWrapper( + buffer=buffer, + encoding="utf8", + newline=newline, + line_buffering=True, + ) + try: + yield cls(wrapper) + finally: + wrapper.detach() + + +class _DryRunWriter: + @staticmethod + def write(line: str) -> None: + # Bypass the log level to always print this during a dry run + log.log(line) + + class OutputWriter: def __init__( self, @@ -259,26 +295,17 @@ def write( markers: Dict[str, Marker], hashes: Optional[Dict[InstallRequirement, Set[str]]], ) -> None: - - if not self.dry_run: - dst_file = io.TextIOWrapper( - self.dst_file, - encoding="utf8", - newline=self.linesep, - line_buffering=True, - ) - try: + cmgr: Union[ + "contextlib.AbstractContextManager[_DryRunWriter]", + "contextlib.AbstractContextManager[_LineWriter]", + ] = ( + contextlib.nullcontext(_DryRunWriter()) + if self.dry_run + else _LineWriter.create(buffer=self.dst_file, newline=self.linesep) + ) + with cmgr as line_writer: for line in self._iter_lines(results, unsafe_requirements, markers, hashes): - if self.dry_run: - # Bypass the log level to always print this during a dry run - log.log(line) - else: - log.info(line) - dst_file.write(unstyle(line)) - dst_file.write("\n") - finally: - if not self.dry_run: - dst_file.detach() + line_writer.write(line) def _format_requirement( self,