-
Notifications
You must be signed in to change notification settings - Fork 125
Add damp support to CG solver (closes #406) #752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,8 +65,9 @@ class CG(Solver): | |
|
|
||
| Notes | ||
| ----- | ||
| Solve the :math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}` problem using conjugate gradient | ||
| iterations [1]_. | ||
| Solve the :math:`\mathbf{y} = (\mathbf{Op} + \epsilon\mathbf{I})\,\mathbf{x}` problem | ||
| using conjugate gradient iterations [1]_, where :math:`\epsilon` is the damping | ||
| coefficient. | ||
|
|
||
| .. [1] Hestenes, M R., Stiefel, E., “Methods of Conjugate Gradients for Solving | ||
| Linear Systems”, Journal of Research of the National Bureau of Standards. | ||
|
|
@@ -134,6 +135,7 @@ def setup( | |
| y: NDArray, | ||
| x0: NDArray | None = None, | ||
| niter: int | None = None, | ||
| damp: float = 0.0, | ||
| tol: float = 1e-4, | ||
| preallocate: bool = False, | ||
| show: bool = False, | ||
|
|
@@ -150,6 +152,8 @@ def setup( | |
| niter : :obj:`int`, optional | ||
| Number of iterations (default to ``None`` in case a user wants to | ||
| manually step over the solver) | ||
| damp : :obj:`float`, optional | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in 6828c57 — shortened to |
||
| Damping coefficient | ||
| tol : :obj:`float`, optional | ||
| Absolute tolerance on residual norm. Stops the solver when the | ||
| residual norm is below this value. | ||
|
|
@@ -170,6 +174,7 @@ def setup( | |
| """ | ||
| self.y = y | ||
| self.niter = niter | ||
| self.damp = damp | ||
| self.tol = tol | ||
|
|
||
| self.ncp = get_array_module(y) | ||
|
|
@@ -187,6 +192,12 @@ def setup( | |
| else: | ||
| self.r = self.ncp.empty_like(self.y) | ||
| self.ncp.subtract(self.y, self.Op.matvec(x), out=self.r) | ||
| # account for the damping term in the initial residual | ||
| if self.damp != 0.0: | ||
| if not self.preallocate: | ||
| self.r = self.r - self.damp * x | ||
| else: | ||
| self.ncp.subtract(self.r, self.damp * x, out=self.r) | ||
| self.c = self.r.copy() | ||
| self.kold = self.ncp.abs(self.r.dot(self.r.conj())) | ||
|
|
||
|
|
@@ -221,6 +232,9 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: | |
|
|
||
| """ | ||
| Opc = self.Op.matvec(self.c) | ||
| # add damping contribution | ||
| if self.damp != 0.0: | ||
| Opc = Opc + self.damp * self.c | ||
| cOpc = self.ncp.abs(self.c.dot(Opc.conj())) | ||
| a = self.kold / cOpc | ||
| if not self.preallocate: | ||
|
|
@@ -317,6 +331,7 @@ def solve( | |
| y: NDArray, | ||
| x0: NDArray | None = None, | ||
| niter: int = 10, | ||
| damp: float = 0.0, | ||
| tol: float = 1e-4, | ||
| preallocate: bool = False, | ||
| show: bool = False, | ||
|
|
@@ -333,6 +348,8 @@ def solve( | |
| internally as zero vector | ||
| niter : :obj:`int`, optional | ||
| Number of iterations | ||
| damp : :obj:`float`, optional | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment as above
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in 6828c57 — trimmed to |
||
| Damping coefficient | ||
| tol : :obj:`float`, optional | ||
| Absolute tolerance on residual norm. Stops the solver when the | ||
| residual norm is below this value. | ||
|
|
@@ -360,7 +377,13 @@ def solve( | |
|
|
||
| """ | ||
| x = self.setup( | ||
| y=y, x0=x0, niter=niter, tol=tol, preallocate=preallocate, show=show | ||
| y=y, | ||
| x0=x0, | ||
| niter=niter, | ||
| damp=damp, | ||
| tol=tol, | ||
| preallocate=preallocate, | ||
| show=show, | ||
| ) | ||
| x = self.run(x, niter, show=show, itershow=itershow) | ||
| self.finalize(show) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just
Damping coefficientas in the other solversThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in 6828c57 — shortened to
Damping coefficient.