1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.ode.nonstiff;
24
25
26 import org.hipparchus.CalculusFieldElement;
27 import org.hipparchus.Field;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.exception.MathIllegalStateException;
30 import org.hipparchus.ode.AbstractFieldIntegrator;
31 import org.hipparchus.ode.FieldEquationsMapper;
32 import org.hipparchus.ode.FieldExpandableODE;
33 import org.hipparchus.ode.FieldODEState;
34 import org.hipparchus.ode.FieldODEStateAndDerivative;
35 import org.hipparchus.ode.FieldOrdinaryDifferentialEquation;
36 import org.hipparchus.util.MathArrays;
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61 public abstract class RungeKuttaFieldIntegrator<T extends CalculusFieldElement<T>>
62 extends AbstractFieldIntegrator<T>
63 implements FieldButcherArrayProvider<T> {
64
65
66 private final T[] c;
67
68
69 private final T[][] a;
70
71
72 private final T[] b;
73
74
75 private final T step;
76
77
78
79
80
81
82
83
84 protected RungeKuttaFieldIntegrator(final Field<T> field, final String name, final T step) {
85 super(field, name);
86 this.c = getC();
87 this.a = getA();
88 this.b = getB();
89 this.step = step.abs();
90 }
91
92
93
94
95 public T getDefaultStep() {
96 return this.step;
97 }
98
99
100
101
102
103
104 protected T fraction(final int p, final int q) {
105 return getField().getZero().add(p).divide(q);
106 }
107
108
109
110
111
112
113
114
115
116 protected abstract RungeKuttaFieldStateInterpolator<T> createInterpolator(boolean forward, T[][] yDotK,
117 FieldODEStateAndDerivative<T> globalPreviousState,
118 FieldODEStateAndDerivative<T> globalCurrentState,
119 FieldEquationsMapper<T> mapper);
120
121
122 @Override
123 public FieldODEStateAndDerivative<T> integrate(final FieldExpandableODE<T> equations,
124 final FieldODEState<T> initialState, final T finalTime)
125 throws MathIllegalArgumentException, MathIllegalStateException {
126
127 sanityChecks(initialState, finalTime);
128 setStepStart(initIntegration(equations, initialState, finalTime));
129 final boolean forward = finalTime.subtract(initialState.getTime()).getReal() > 0;
130
131
132 final int stages = c.length + 1;
133 final T[][] yDotK = MathArrays.buildArray(getField(), stages, -1);
134 final T[] yTmp = MathArrays.buildArray(getField(), equations.getMapper().getTotalDimension());
135
136
137 if (forward) {
138 if (getStepStart().getTime().add(step).subtract(finalTime).getReal() >= 0) {
139 setStepSize(finalTime.subtract(getStepStart().getTime()));
140 } else {
141 setStepSize(step);
142 }
143 } else {
144 if (getStepStart().getTime().subtract(step).subtract(finalTime).getReal() <= 0) {
145 setStepSize(finalTime.subtract(getStepStart().getTime()));
146 } else {
147 setStepSize(step.negate());
148 }
149 }
150
151
152 setIsLastStep(false);
153 do {
154
155
156 final T[] y = getStepStart().getCompleteState();
157 yDotK[0] = getStepStart().getCompleteDerivative();
158
159
160 for (int k = 1; k < stages; ++k) {
161
162 for (int j = 0; j < y.length; ++j) {
163 T sum = yDotK[0][j].multiply(a[k-1][0]);
164 for (int l = 1; l < k; ++l) {
165 sum = sum.add(yDotK[l][j].multiply(a[k-1][l]));
166 }
167 yTmp[j] = y[j].add(getStepSize().multiply(sum));
168 }
169
170 yDotK[k] = computeDerivatives(getStepStart().getTime().add(getStepSize().multiply(c[k-1])), yTmp);
171
172 }
173
174
175 for (int j = 0; j < y.length; ++j) {
176 T sum = yDotK[0][j].multiply(b[0]);
177 for (int l = 1; l < stages; ++l) {
178 sum = sum.add(yDotK[l][j].multiply(b[l]));
179 }
180 yTmp[j] = y[j].add(getStepSize().multiply(sum));
181 }
182 final T stepEnd = getStepStart().getTime().add(getStepSize());
183 final T[] yDotTmp = computeDerivatives(stepEnd, yTmp);
184 final FieldODEStateAndDerivative<T> stateTmp = equations.getMapper().mapStateAndDerivative(stepEnd, yTmp, yDotTmp);
185
186
187 setStepStart(acceptStep(createInterpolator(forward, yDotK, getStepStart(), stateTmp, equations.getMapper()),
188 finalTime));
189
190 if (!isLastStep()) {
191
192
193 final T nextT = getStepStart().getTime().add(getStepSize());
194 final boolean nextIsLast = forward ?
195 (nextT.subtract(finalTime).getReal() >= 0) :
196 (nextT.subtract(finalTime).getReal() <= 0);
197 if (nextIsLast) {
198 setStepSize(finalTime.subtract(getStepStart().getTime()));
199 }
200 }
201
202 } while (!isLastStep());
203
204 final FieldODEStateAndDerivative<T> finalState = getStepStart();
205 setStepStart(null);
206 setStepSize(null);
207 return finalState;
208
209 }
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236 public T[] singleStep(final FieldOrdinaryDifferentialEquation<T> equations,
237 final T t0, final T[] y0, final T t) {
238
239
240 final T[] y = y0.clone();
241 final int stages = c.length + 1;
242 final T[][] yDotK = MathArrays.buildArray(getField(), stages, -1);
243 final T[] yTmp = y0.clone();
244
245
246 final T h = t.subtract(t0);
247 yDotK[0] = equations.computeDerivatives(t0, y);
248
249
250 for (int k = 1; k < stages; ++k) {
251
252 for (int j = 0; j < y0.length; ++j) {
253 T sum = yDotK[0][j].multiply(a[k-1][0]);
254 for (int l = 1; l < k; ++l) {
255 sum = sum.add(yDotK[l][j].multiply(a[k-1][l]));
256 }
257 yTmp[j] = y[j].add(h.multiply(sum));
258 }
259
260 yDotK[k] = equations.computeDerivatives(t0.add(h.multiply(c[k-1])), yTmp);
261
262 }
263
264
265 for (int j = 0; j < y0.length; ++j) {
266 T sum = yDotK[0][j].multiply(b[0]);
267 for (int l = 1; l < stages; ++l) {
268 sum = sum.add(yDotK[l][j].multiply(b[l]));
269 }
270 y[j] = y[j].add(h.multiply(sum));
271 }
272
273 return y;
274
275 }
276
277 }