@@ -1304,82 +1304,63 @@ def eigvalsh(a, b, lower=True):
 class Expm(Op):
     """
     Compute the matrix exponential of a square array.
-
     """

     __props__ = ()
+    gufunc_signature = "(m,m)->(m,m)"

     def make_node(self, A):
         A = as_tensor_variable(A)
         assert A.ndim == 2
-        expm = matrix(dtype=A.dtype)
-        return Apply(
-            self,
-            [
-                A,
-            ],
-            [
-                expm,
-            ],
-        )
+
+        expm = matrix(dtype=A.dtype, shape=A.type.shape)
+
+        return Apply(self, [A], [expm])

     def perform(self, node, inputs, outputs):
         (A,) = inputs
         (expm,) = outputs
         expm[0] = scipy_linalg.expm(A)

-    def grad(self, inputs, outputs):
+    def L_op(self, inputs, outputs, output_grads):
+        # Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
+        # Kind of... You need to do some algebra from there to arrive at
+        # this expression.
         (A,) = inputs
-        (g_out,) = outputs
-        return [ExpmGrad()(A, g_out)]
-
-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        (_,) = outputs  # Outputs not used; included for signature consistency only
+        (A_bar,) = output_grads

+        w, V = pt.linalg.eig(A, return_components=True)

-class ExpmGrad(Op):
-    """
-    Gradient of the matrix exponential of a square array.
+        w = w[0] + 1j * w[1]
+        V = V[0] + 1j * V[1]

-    """
+        exp_w = pt.exp(w)
+        numer = pt.sub.outer(exp_w, exp_w)
+        denom = pt.sub.outer(w, w)

-    __props__ = ()
+        # When w_i ≈ w_j, we have a removable singularity in the expression for X, because
+        # lim b->a (e^a - e^b) / (a - b) = e^a (derivation left for the motivated reader)
+        X = pt.where(pt.abs(denom) < 1e-8, exp_w, numer / denom)

-    def make_node(self, A, gw):
-        A = as_tensor_variable(A)
-        assert A.ndim == 2
-        out = matrix(dtype=A.dtype)
-        return Apply(
-            self,
-            [A, gw],
-            [
-                out,
-            ],
-        )
+        diag_idx = pt.arange(w.shape[0])
+        X = X[..., diag_idx, diag_idx].set(exp_w)

-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        inner = solve(V, A_bar.T @ V).T
+        result = solve(V.T, inner * X) @ V.T

-    def perform(self, node, inputs, outputs):
-        # Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
-        # Kind of... You need to do some algebra from there to arrive at
-        # this expression.
-        (A, gA) = inputs
-        (out,) = outputs
-        w, V = scipy_linalg.eig(A, right=True)
-        U = scipy_linalg.inv(V).T
+        # At this point, result is always a complex dtype. If the input was real,
+        # the output should be real as well (and all the imaginary parts are numerical noise)
+        if A.dtype not in ("complex64", "complex128"):
+            return [result.real]

-        exp_w = np.exp(w)
-        X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
-        np.fill_diagonal(X, exp_w)
-        Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T)
+        return [result]

-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", ComplexWarning)
-            out[0] = Y.astype(A.dtype)
+    def infer_shape(self, fgraph, node, shapes):
+        return [shapes[0]]


-expm = Expm()
+expm = Blockwise(Expm())


 class SolveContinuousLyapunov(Op):
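
For reference, the identity implemented by the new L_op can be checked numerically outside of PyTensor. The sketch below is not part of the commit; it evaluates the same eigendecomposition formula with NumPy/SciPy (variable names mirror the diff) and compares it against central finite differences of scipy.linalg.expm. The singularity guard works because (e^a - e^b) / (a - b) is the difference quotient of exp, whose limit as b -> a is e^a; that is exactly what the diagonal fill encodes.

    import numpy as np
    from scipy import linalg as scipy_linalg

    rng = np.random.default_rng(0)
    m = 4
    A = rng.normal(size=(m, m))
    A_bar = rng.normal(size=(m, m))  # upstream gradient ("output_grads" in L_op)

    w, V = scipy_linalg.eig(A)  # complex eigvals/eigvecs of a real matrix
    exp_w = np.exp(w)
    numer = np.subtract.outer(exp_w, exp_w)
    denom = np.subtract.outer(w, w)

    # Guard the removable singularity: where w_i ~ w_j (including the diagonal),
    # use e^{w_i} directly instead of 0/0
    small = np.abs(denom) < 1e-8
    X = np.where(small, exp_w, numer / np.where(small, 1.0, denom))

    inner = scipy_linalg.solve(V, A_bar.T @ V).T
    result = (scipy_linalg.solve(V.T, inner * X) @ V.T).real

    # Central finite differences of <A_bar, expm(A)> with respect to each A_ij
    eps = 1e-6
    fd = np.empty_like(A)
    for i in range(m):
        for j in range(m):
            E = np.zeros_like(A)
            E[i, j] = eps
            fd[i, j] = (
                A_bar * (scipy_linalg.expm(A + E) - scipy_linalg.expm(A - E))
            ).sum() / (2 * eps)

    np.testing.assert_allclose(result, fd, rtol=1e-6, atol=1e-8)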
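A minimal usage sketch, assuming the op keeps its current home at pytensor.tensor.slinalg (per the surrounding file); everything else is standard PyTensor API:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import expm

    A = pt.matrix("A")
    g = pytensor.grad(expm(A).sum(), A)  # now routed through Expm.L_op; no separate ExpmGrad op
    f = pytensor.function([A], g)
    print(f(np.eye(3, dtype=A.dtype)))

    # Because expm is now wrapped in Blockwise with gufunc signature "(m,m)->(m,m)",
    # leading dimensions act as batch dimensions:
    A_batch = pt.tensor3("A_batch")  # shape (batch, m, m)
    E_batch = expm(A_batch)          # matrix exponential of each (m, m) slice

Since Blockwise dispatches on gufunc_signature, mapping expm over a stack of matrices no longer needs an explicit scan.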