View Javadoc
1   /*
2    * Licensed to the Hipparchus project under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.hipparchus.ode.nonstiff.interpolators;
19  
20  import org.hipparchus.CalculusFieldElement;
21  import org.hipparchus.Field;
22  import org.hipparchus.ode.EquationsMapper;
23  import org.hipparchus.ode.FieldEquationsMapper;
24  import org.hipparchus.ode.FieldExpandableODE;
25  import org.hipparchus.ode.FieldODEStateAndDerivative;
26  import org.hipparchus.ode.FieldOrdinaryDifferentialEquation;
27  import org.hipparchus.ode.ODEStateAndDerivative;
28  import org.hipparchus.ode.nonstiff.FieldButcherArrayProvider;
29  import org.hipparchus.ode.nonstiff.FieldODEStateInterpolatorAbstractTest;
30  import org.hipparchus.ode.sampling.AbstractFieldODEStateInterpolator;
31  import org.hipparchus.ode.sampling.FieldODEStateInterpolator;
32  import org.hipparchus.util.MathArrays;
33  
34  import java.lang.reflect.InvocationTargetException;
35  
36  import static org.junit.jupiter.api.Assertions.fail;
37  
38  public abstract class RungeKuttaFieldStateInterpolatorAbstractTest extends FieldODEStateInterpolatorAbstractTest {
39  
40      protected abstract <T extends CalculusFieldElement<T>> RungeKuttaFieldStateInterpolator<T>
41          createInterpolator(Field<T> field, boolean forward, T[][] yDotK,
42                             FieldODEStateAndDerivative<T> globalPreviousState,
43                             FieldODEStateAndDerivative<T> globalCurrentState,
44                             FieldODEStateAndDerivative<T> softPreviousState,
45                             FieldODEStateAndDerivative<T> softCurrentState,
46                             FieldEquationsMapper<T> mapper);
47  
48      protected abstract <T extends CalculusFieldElement<T>> FieldButcherArrayProvider<T>
49          createButcherArrayProvider(final Field<T> field);
50  
51      protected <T extends CalculusFieldElement<T>>
52      RungeKuttaFieldStateInterpolator<T> setUpInterpolator(final Field<T> field,
53                                                            final ReferenceFieldODE<T> eqn,
54                                                            final double t0, final double[] y0,
55                                                            final double t1) {
56  
57          // get the Butcher arrays from the field integrator
58          FieldButcherArrayProvider<T> provider = createButcherArrayProvider(field);
59          T[][] a = provider.getA();
60          T[]   b = provider.getB();
61          T[]   c = provider.getC();
62  
63          // store initial state
64          T     t          = field.getZero().add(t0);
65          T[]   fieldY     = MathArrays.buildArray(field, eqn.getDimension());
66          T[][] fieldYDotK = MathArrays.buildArray(field, b.length, -1);
67          for (int i = 0; i < y0.length; ++i) {
68              fieldY[i] = field.getZero().add(y0[i]);
69          }
70          fieldYDotK[0] = eqn.computeDerivatives(t, fieldY);
71          FieldODEStateAndDerivative<T> s0 = new FieldODEStateAndDerivative<T>(t, fieldY, fieldYDotK[0]);
72  
73          // perform one integration step, in order to get consistent derivatives
74          T h = field.getZero().add(t1 - t0);
75          for (int k = 0; k < a.length; ++k) {
76              for (int i = 0; i < y0.length; ++i) {
77                  fieldY[i] = field.getZero().add(y0[i]);
78                  for (int s = 0; s <= k; ++s) {
79                      fieldY[i] = fieldY[i].add(h.multiply(a[k][s].multiply(fieldYDotK[s][i])));
80                  }
81              }
82              fieldYDotK[k + 1] = eqn.computeDerivatives(h.multiply(c[k]).add(t0), fieldY);
83          }
84  
85          // store state at step end
86          t = field.getZero().add(t1);
87          for (int i = 0; i < y0.length; ++i) {
88              fieldY[i] = field.getZero().add(y0[i]);
89              for (int s = 0; s < b.length; ++s) {
90                  fieldY[i] = fieldY[i].add(h.multiply(b[s].multiply(fieldYDotK[s][i])));
91              }
92          }
93          FieldODEStateAndDerivative<T> s1 = new FieldODEStateAndDerivative<T>(t, fieldY,
94                                                                               eqn.computeDerivatives(t, fieldY));
95  
96          return createInterpolator(field, t1 > t0, fieldYDotK, s0, s1, s0, s1,
97                                    new FieldExpandableODE<T>(eqn).getMapper());
98  
99      }
100 
101     protected <T extends CalculusFieldElement<T>>
102     RungeKuttaStateInterpolator convertInterpolator(final FieldODEStateInterpolator<T> fieldInterpolator,
103                                                     final FieldOrdinaryDifferentialEquation<T> eqn) {
104 
105         RungeKuttaFieldStateInterpolator<T> rkFieldInterpolator =
106                         (RungeKuttaFieldStateInterpolator<T>) fieldInterpolator;
107 
108         RungeKuttaStateInterpolator regularInterpolator = null;
109         try {
110 
111             String interpolatorName = rkFieldInterpolator.getClass().getName();
112             String integratorName = interpolatorName.replaceAll("Field", "");
113             @SuppressWarnings("unchecked")
114             Class<RungeKuttaStateInterpolator> clz = (Class<RungeKuttaStateInterpolator>) Class.forName(integratorName);
115 
116             java.lang.reflect.Field fYD = RungeKuttaFieldStateInterpolator.class.getDeclaredField("yDotK");
117             fYD.setAccessible(true);
118             @SuppressWarnings("unchecked")
119             final double[][] yDotK = convertArray((T[][]) fYD.get(rkFieldInterpolator));
120 
121             java.lang.reflect.Field fMapper = AbstractFieldODEStateInterpolator.class.getDeclaredField("mapper");
122             fMapper.setAccessible(true);
123             @SuppressWarnings("unchecked")
124             EquationsMapper regularMapper = convertMapper((FieldEquationsMapper<T>) fMapper.get(rkFieldInterpolator));
125 
126             java.lang.reflect.Constructor<RungeKuttaStateInterpolator> regularInterpolatorConstructor =
127                             clz.getDeclaredConstructor(Boolean.TYPE,
128                                                        double[][].class,
129                                                        ODEStateAndDerivative.class,
130                                                        ODEStateAndDerivative.class,
131                                                        ODEStateAndDerivative.class,
132                                                        ODEStateAndDerivative.class,
133                                                        EquationsMapper.class);
134             return regularInterpolatorConstructor.newInstance(rkFieldInterpolator.isForward(),
135                                                               yDotK,
136                                                               convertODEStateAndDerivative(rkFieldInterpolator.getGlobalPreviousState()),
137                                                               convertODEStateAndDerivative(rkFieldInterpolator.getGlobalCurrentState()),
138                                                               convertODEStateAndDerivative(rkFieldInterpolator.getPreviousState()),
139                                                               convertODEStateAndDerivative(rkFieldInterpolator.getCurrentState()),
140                                                               regularMapper);
141 
142         } catch (ClassNotFoundException | InstantiationException   | IllegalAccessException    |
143                  NoSuchFieldException   | IllegalArgumentException | InvocationTargetException |
144                  NoSuchMethodException  | SecurityException e) {
145             fail(e.getLocalizedMessage());
146         }
147 
148         return regularInterpolator;
149 
150     }
151 
152 }