Skip to content

Commit

Permalink
removing unnecessary loop
Browse files Browse the repository at this point in the history
  • Loading branch information
neka-nat committed Jun 15, 2020
1 parent f861583 commit 639193c
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions probreg/filterreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(self, source=None, target_normals=None,

@staticmethod
def _maximization_step(t_source, target, estep_res, trans_p, sigma2, w=0.0,
objective_type='pt2pt', maxiter=10, tol=1.0e-4):
objective_type='pt2pt'):
m, dim = t_source.shape
n = target.shape[0]
assert dim == 2 or dim == 3, "dim must be 2 or 3."
Expand Down Expand Up @@ -190,15 +190,14 @@ def __init__(self, source=None, skinning_weight=None,

@staticmethod
def _maximization_step(t_source, target, estep_res, trans_p, sigma2, w=0.0,
objective_type='', maxiter=50, tol=1.0e-4):
objective_type=''):
m, dim = t_source.shape
n6d = dim * 2
idx_6d = lambda i: slice(i * n6d, (i + 1) * n6d)
n = target.shape[0]
n_nodes = trans_p.weights.n_nodes
assert dim == 3, "dim must be 3."
m0, m1, m2, _ = estep_res
tw = np.zeros(n_nodes * dim * 2)
c = w / (1.0 - w) * n / m
m0[m0==0] = np.finfo(np.float32).eps
m1m0 = np.divide(m1.T, m0).T
Expand All @@ -214,29 +213,22 @@ def _maximization_step(t_source, target, estep_res, trans_p, sigma2, w=0.0,
jtj_tw += w[0] * w[1] * np.dot(drxdz.T, drxdz)
a[idx_6d(pair[0]), idx_6d(pair[1])] += jtj_tw
a[idx_6d(pair[1]), idx_6d(pair[0])] += jtj_tw
for _ in range(maxiter):
x = np.zeros_like(t_source)
for pair in trans_p.weights.pairs_set():
for idx in trans_p.weights.in_pair(pair):
w = trans_p.weights[idx]['val']
q0 = dualquat_from_twist(tw[idx_6d(pair[0])])
q1 = dualquat_from_twist(tw[idx_6d(pair[1])])
x[idx] = (w[0] * q0 + w[1] * q1).transform_point(t_source[idx])
x = np.zeros_like(t_source)
for pair in trans_p.weights.pairs_set():
for idx in trans_p.weights.in_pair(pair):
x[idx] = t_source[idx]

rx = np.multiply(drxdx, (x - m1m0).T).T
b = np.zeros(n_nodes * n6d)
for pair in trans_p.weights.pairs_set():
j_tw = np.zeros(n6d)
for idx in trans_p.weights.in_pair(pair):
drxdz = drxdx[idx] * dxdz[idx]
w = trans_p.weights[idx]['val']
j_tw += w[0] * np.dot(drxdz.T, rx[idx])
b[idx_6d(pair[0])] += j_tw
rx = np.multiply(drxdx, (x - m1m0).T).T
b = np.zeros(n_nodes * n6d)
for pair in trans_p.weights.pairs_set():
j_tw = np.zeros(n6d)
for idx in trans_p.weights.in_pair(pair):
drxdz = drxdx[idx] * dxdz[idx]
w = trans_p.weights[idx]['val']
j_tw += w[0] * np.dot(drxdz.T, rx[idx])
b[idx_6d(pair[0])] += j_tw

dtw = np.linalg.lstsq(a, b, rcond=None)[0]
tw -= dtw
if np.linalg.norm(dtw) < tol:
break
tw = -np.linalg.lstsq(a, b, rcond=None)[0]

dualquats = [dualquat_from_twist(tw[idx_6d(i)]) * dq for i, dq in enumerate(trans_p.dualquats)]
if not m2 is None:
Expand Down

0 comments on commit 639193c

Please sign in to comment.