from sympy import *
from IPython.display import display
# Configure SymPy pretty-printing so objects passed to display() render as
# LaTeX typeset with MathJax (intended for a Jupyter/IPython front end).
init_printing(use_latex='mathjax')
def height(M):
    """Return the number of rows of the 2-D matrix ``M``.

    Reads ``M.shape`` instead of slicing out column 0, so it works for SymPy
    matrices and NumPy arrays alike and also for a matrix with zero columns
    (where indexing column 0 fails).
    """
    return M.shape[0]


def width(M):
    """Return the number of columns of the 2-D matrix ``M``.

    Reads ``M.shape`` instead of slicing out row 0, so a matrix with zero
    rows still reports its column count instead of failing on ``M[0, :]``.
    """
    return M.shape[1]
# Example 1: tall 3x2 matrix A = [[1,2],[3,4],[5,6]] augmented with I_3.
# Row-reducing [A | I] yields [L | R], where L = rref(A) and R holds the
# accumulated row operations (R*A = L).  For this full-column-rank A,
# L = [[1,0],[0,1],[0,0]], so B = L^T * R satisfies B*A = I_2 (a left
# inverse); A*B is displayed for comparison.
Ai = Matrix([
    [1, 2, 1, 0, 0],
    [3, 4, 0, 1, 0],
    [5, 6, 0, 0, 1],
])
h = height(Ai)
w = width(Ai) - height(Ai)
A = Ai[:, :w]
display(A)
display(Ai)
Ar, pivots = Ai.rref()
display(Ar)
L, R = Ar[:, :w], Ar[:, w:w + h]
display(L.T, R)
B = L.T * R
display(B)
display(B * A)
display(A * B)
# Example 2: wide 2x3 matrix A = [[1,2,3],[4,5,6]] augmented with I_2.
# Row-reducing [A | I] yields [rref(A) | R] with R*A = rref(A).  This A has a
# non-pivot (free) column, so P keeps only the pivot columns of rref(A) and
# zeroes the rest; with this data B = P^T * R satisfies A*B = I_2 (a right
# inverse), while B*A is displayed for comparison.
Ai=Matrix([[1,2,3,1,0],[4,5,6,0,1]])
h=height(Ai)
w=width(Ai)-height(Ai)
A=Ai[:,0:w]
display(A)
display(Ai)
Ar,pivots=Ai.rref()
display(Ar)
# Slicing copies, so zeroing entries of P leaves Ar untouched.
P=Ar[:,0:w]
# Zero every non-pivot column of rref(A).  The pivot test depends only on the
# column index, so it is checked once per column rather than once per entry.
for j in range(w):
    if j not in pivots:
        for i in range(h):
            P[i,j]=0
R=Ar[:,w:w+h]
display(P.transpose(),R)
B=P.transpose()*R
display(B)
display(A*B)
display(B*A)
# Example 3: square 3x3 matrix A = [[1,2,4],[3,5,6],[7,8,9]] augmented with
# I_3.  This A is invertible (det = -17), so rref(A) = I_3 and the right-hand
# block R of rref([A | I]) is A^{-1}.  Hence B = L^T * R equals R, and both
# B*A and A*B display the 3x3 identity.
Ai = Matrix([[1, 2, 4, 1, 0, 0],
             [3, 5, 6, 0, 1, 0],
             [7, 8, 9, 0, 0, 1]])
h, w = height(Ai), width(Ai) - height(Ai)
A = Ai[:, :w]
display(A)
display(Ai)
Ar, pivots = Ai.rref()
display(Ar)
L = Ar[:, :w]
R = Ar[:, w:w + h]
display(L.T, R)
B = L.T * R
display(B)
display(B * A)
display(A * B)