3939from copy import copy
4040from typing import Callable , Dict , List , Optional , Sequence , Tuple , Union
4141
42+ import numpy as np
4243import pytensor .tensor as at
4344
4445from pytensor .gradient import DisconnectedType , jacobian
4849from pytensor .graph .op import Op
4950from pytensor .graph .replace import clone_replace
5051from pytensor .graph .rewriting .basic import GraphRewriter , in2out , node_rewriter
51- from pytensor .scalar import Add , Exp , Log , Mul , Reciprocal
52+ from pytensor .scalar import Add , Exp , Log , Mul , Pow , Sqr , Sqrt
5253from pytensor .scan .op import Scan
5354from pytensor .tensor .exceptions import NotScalarConstantError
54- from pytensor .tensor .math import add , exp , log , mul , neg , reciprocal , sub , true_div
55+ from pytensor .tensor .math import (
56+ add ,
57+ exp ,
58+ log ,
59+ mul ,
60+ neg ,
61+ pow ,
62+ reciprocal ,
63+ sqr ,
64+ sqrt ,
65+ sub ,
66+ true_div ,
67+ )
5568from pytensor .tensor .rewriting .basic import (
5669 register_specialize ,
5770 register_stabilize ,
@@ -110,8 +123,11 @@ def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
110123 """Apply the transformation."""
111124
112125 @abc .abstractmethod
113- def backward (self , value : TensorVariable , * inputs : Variable ) -> TensorVariable :
114- """Invert the transformation."""
126+ def backward (
127+ self , value : TensorVariable , * inputs : Variable
128+ ) -> Union [TensorVariable , Tuple [TensorVariable , ...]]:
129+ """Invert the transformation. Multiple values may be returned when the
130+ transformation is not 1-to-1"""
115131
116132 def log_jac_det (self , value : TensorVariable , * inputs ) -> TensorVariable :
117133 """Construct the log of the absolute value of the Jacobian determinant."""
@@ -320,7 +336,7 @@ def apply(self, fgraph: FunctionGraph):
320336class MeasurableTransform (MeasurableElemwise ):
321337 """A placeholder used to specify a log-likelihood for a transformed measurable variable"""
322338
323- valid_scalar_types = (Exp , Log , Add , Mul , Reciprocal )
339+ valid_scalar_types = (Exp , Log , Add , Mul , Pow )
324340
325341 # Cannot use `transform` as name because it would clash with the property added by
326342 # the `TransformValuesRewrite`
@@ -349,16 +365,64 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
349365 # The value variable must still be back-transformed to be on the natural support of
350366 # the respective measurable input.
351367 backward_value = op .transform_elemwise .backward (value , * other_inputs )
352- input_logprob = logprob (measurable_input , backward_value , ** kwargs )
368+
369+ # Some transformations, like squaring may produce multiple backward values
370+ if isinstance (backward_value , tuple ):
371+ input_logprob = at .logaddexp (
372+ * (logprob (measurable_input , backward_val , ** kwargs ) for backward_val in backward_value )
373+ )
374+ else :
375+ input_logprob = logprob (measurable_input , backward_value )
353376
354377 jacobian = op .transform_elemwise .log_jac_det (value , * other_inputs )
355378
356379 return input_logprob + jacobian
357380
358381
382+ @node_rewriter ([reciprocal ])
383+ def measurable_reciprocal_to_power (fgraph , node ):
384+ """Convert reciprocal of `MeasurableVariable`s to power."""
385+ inp = node .inputs [0 ]
386+ if not (inp .owner and isinstance (inp .owner .op , MeasurableVariable )):
387+ return None
388+
389+ rv_map_feature : Optional [PreserveRVMappings ] = getattr (fgraph , "preserve_rv_mappings" , None )
390+ if rv_map_feature is None :
391+ return None # pragma: no cover
392+
393+ # Only apply this rewrite if the variable is unvalued
394+ if inp in rv_map_feature .rv_values :
395+ return None # pragma: no cover
396+
397+ return [at .pow (inp , - 1.0 )]
398+
399+
400+ @node_rewriter ([sqr , sqrt ])
401+ def measurable_sqrt_sqr_to_power (fgraph , node ):
402+ """Convert square root or square of `MeasurableVariable`s to power form."""
403+
404+ inp = node .inputs [0 ]
405+ if not (inp .owner and isinstance (inp .owner .op , MeasurableVariable )):
406+ return None
407+
408+ rv_map_feature : Optional [PreserveRVMappings ] = getattr (fgraph , "preserve_rv_mappings" , None )
409+ if rv_map_feature is None :
410+ return None # pragma: no cover
411+
412+ # Only apply this rewrite if the variable is unvalued
413+ if inp in rv_map_feature .rv_values :
414+ return None # pragma: no cover
415+
416+ if isinstance (node .op .scalar_op , Sqr ):
417+ return [at .pow (inp , 2 )]
418+
419+ if isinstance (node .op .scalar_op , Sqrt ):
420+ return [at .pow (inp , 1 / 2 )]
421+
422+
359423@node_rewriter ([true_div ])
360- def measurable_div_to_reciprocal_product (fgraph , node ):
361- """Convert divisions involving `MeasurableVariable`s to product with reciprocal ."""
424+ def measurable_div_to_product (fgraph , node ):
425+ """Convert divisions involving `MeasurableVariable`s to products ."""
362426
363427 measurable_vars = [
364428 var for var in node .inputs if (var .owner and isinstance (var .owner .op , MeasurableVariable ))
@@ -379,9 +443,13 @@ def measurable_div_to_reciprocal_product(fgraph, node):
379443 # Check if numerator is 1
380444 try :
381445 if at .get_scalar_constant_value (numerator ) == 1 :
382- return [at .reciprocal (denominator )]
446+ # We convert the denominator directly to a power transform as this
447+ # must be the measurable input
448+ return [at .pow (denominator , - 1 )]
383449 except NotScalarConstantError :
384450 pass
451+ # We don't convert the denominator directly to a power transform as
452+ # it might not be measurable (and therefore not needed)
385453 return [at .mul (numerator , at .reciprocal (denominator ))]
386454
387455
@@ -425,7 +493,7 @@ def measurable_sub_to_neg(fgraph, node):
425493 return [at .add (minuend , at .neg (subtrahend ))]
426494
427495
428- @node_rewriter ([exp , log , add , mul , reciprocal ])
496+ @node_rewriter ([exp , log , add , mul , pow ])
429497def find_measurable_transforms (fgraph : FunctionGraph , node : Node ) -> Optional [List [Node ]]:
430498 """Find measurable transformations from Elemwise operators."""
431499
@@ -485,8 +553,18 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
485553 transform = ExpTransform ()
486554 elif isinstance (scalar_op , Log ):
487555 transform = LogTransform ()
488- elif isinstance (scalar_op , Reciprocal ):
489- transform = ReciprocalTransform ()
556+ elif isinstance (scalar_op , Pow ):
557+ # We only allow for the base to be measurable
558+ if measurable_input_idx != 0 :
559+ return None
560+ try :
561+ (power ,) = other_inputs
562+ power = at .get_scalar_constant_value (power ).item ()
563+ # Power needs to be a constant
564+ except NotScalarConstantError :
565+ return None
566+ transform_inputs = (measurable_input , power )
567+ transform = PowerTransform (power = power )
490568 elif isinstance (scalar_op , Add ):
491569 transform_inputs = (measurable_input , at .add (* other_inputs ))
492570 transform = LocTransform (
@@ -510,12 +588,29 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
510588
511589
512590measurable_ir_rewrites_db .register (
513- "measurable_div_to_reciprocal_product " ,
514- measurable_div_to_reciprocal_product ,
591+ "measurable_reciprocal_to_power " ,
592+ measurable_reciprocal_to_power ,
515593 "basic" ,
516594 "transform" ,
517595)
518596
597+
598+ measurable_ir_rewrites_db .register (
599+ "measurable_sqrt_sqr_to_power" ,
600+ measurable_sqrt_sqr_to_power ,
601+ "basic" ,
602+ "transform" ,
603+ )
604+
605+
606+ measurable_ir_rewrites_db .register (
607+ "measurable_div_to_product" ,
608+ measurable_div_to_product ,
609+ "basic" ,
610+ "transform" ,
611+ )
612+
613+
519614measurable_ir_rewrites_db .register (
520615 "measurable_neg_to_product" ,
521616 measurable_neg_to_product ,
@@ -601,17 +696,33 @@ def log_jac_det(self, value, *inputs):
601696 return - at .log (value )
602697
603698
604- class ReciprocalTransform (RVTransform ):
605- name = "reciprocal"
699+ class PowerTransform (RVTransform ):
700+ name = "power"
701+
702+ def __init__ (self , power = None ):
703+ if not isinstance (power , (int , float )):
704+ raise TypeError (f"Power must be integer or float, got { type (power )} " )
705+ if power == 0 :
706+ raise ValueError ("Power cannot be 0" )
707+ self .power = power
708+ super ().__init__ ()
606709
607710 def forward (self , value , * inputs ):
608- return at .reciprocal (value )
711+ at .power (value , self . power )
609712
610713 def backward (self , value , * inputs ):
611- return at .reciprocal (value )
714+ backward_value = at .power (value , (1 / self .power ))
715+
716+ # In this case the transform is not 1-to-1
717+ if (self .power > 1 ) and (self .power % 2 == 0 ):
718+ return - backward_value , backward_value
719+ else :
720+ return backward_value
612721
613722 def log_jac_det (self , value , * inputs ):
614- return - 2 * at .log (value )
723+ inv_power = 1 / self .power
724+ # Note: This fails for value==0
725+ return np .log (np .abs (inv_power )) + (inv_power - 1 ) * at .log (value )
615726
616727
617728class IntervalTransform (RVTransform ):
0 commit comments