View Javadoc
1   /*
2    * Licensed to the Hipparchus project under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.hipparchus.ode.nonstiff;
19  
20  
21  import java.lang.reflect.InvocationTargetException;
22  
23  import org.hipparchus.Field;
24  import org.hipparchus.CalculusFieldElement;
25  import org.hipparchus.ode.EquationsMapper;
26  import org.hipparchus.ode.FieldEquationsMapper;
27  import org.hipparchus.ode.FieldExpandableODE;
28  import org.hipparchus.ode.FieldODEStateAndDerivative;
29  import org.hipparchus.ode.FieldOrdinaryDifferentialEquation;
30  import org.hipparchus.ode.ODEStateAndDerivative;
31  import org.hipparchus.ode.sampling.AbstractFieldODEStateInterpolator;
32  import org.hipparchus.ode.sampling.FieldODEStateInterpolator;
33  import org.hipparchus.util.MathArrays;
34  import org.junit.Assert;
35  
36  public abstract class RungeKuttaFieldStateInterpolatorAbstractTest extends FieldODEStateInterpolatorAbstractTest {
37  
38      protected abstract <T extends CalculusFieldElement<T>> RungeKuttaFieldStateInterpolator<T>
39          createInterpolator(Field<T> field, boolean forward, T[][] yDotK,
40                             FieldODEStateAndDerivative<T> globalPreviousState,
41                             FieldODEStateAndDerivative<T> globalCurrentState,
42                             FieldODEStateAndDerivative<T> softPreviousState,
43                             FieldODEStateAndDerivative<T> softCurrentState,
44                             FieldEquationsMapper<T> mapper);
45  
46      protected abstract <T extends CalculusFieldElement<T>> FieldButcherArrayProvider<T>
47          createButcherArrayProvider(final Field<T> field);
48  
49      protected <T extends CalculusFieldElement<T>>
50      RungeKuttaFieldStateInterpolator<T> setUpInterpolator(final Field<T> field,
51                                                            final ReferenceFieldODE<T> eqn,
52                                                            final double t0, final double[] y0,
53                                                            final double t1) {
54  
55          // get the Butcher arrays from the field integrator
56          FieldButcherArrayProvider<T> provider = createButcherArrayProvider(field);
57          T[][] a = provider.getA();
58          T[]   b = provider.getB();
59          T[]   c = provider.getC();
60  
61          // store initial state
62          T     t          = field.getZero().add(t0);
63          T[]   fieldY     = MathArrays.buildArray(field, eqn.getDimension());
64          T[][] fieldYDotK = MathArrays.buildArray(field, b.length, -1);
65          for (int i = 0; i < y0.length; ++i) {
66              fieldY[i] = field.getZero().add(y0[i]);
67          }
68          fieldYDotK[0] = eqn.computeDerivatives(t, fieldY);
69          FieldODEStateAndDerivative<T> s0 = new FieldODEStateAndDerivative<T>(t, fieldY, fieldYDotK[0]);
70  
71          // perform one integration step, in order to get consistent derivatives
72          T h = field.getZero().add(t1 - t0);
73          for (int k = 0; k < a.length; ++k) {
74              for (int i = 0; i < y0.length; ++i) {
75                  fieldY[i] = field.getZero().add(y0[i]);
76                  for (int s = 0; s <= k; ++s) {
77                      fieldY[i] = fieldY[i].add(h.multiply(a[k][s].multiply(fieldYDotK[s][i])));
78                  }
79              }
80              fieldYDotK[k + 1] = eqn.computeDerivatives(h.multiply(c[k]).add(t0), fieldY);
81          }
82  
83          // store state at step end
84          t = field.getZero().add(t1);
85          for (int i = 0; i < y0.length; ++i) {
86              fieldY[i] = field.getZero().add(y0[i]);
87              for (int s = 0; s < b.length; ++s) {
88                  fieldY[i] = fieldY[i].add(h.multiply(b[s].multiply(fieldYDotK[s][i])));
89              }
90          }
91          FieldODEStateAndDerivative<T> s1 = new FieldODEStateAndDerivative<T>(t, fieldY,
92                                                                               eqn.computeDerivatives(t, fieldY));
93  
94          return createInterpolator(field, t1 > t0, fieldYDotK, s0, s1, s0, s1,
95                                    new FieldExpandableODE<T>(eqn).getMapper());
96  
97      }
98  
99      protected <T extends CalculusFieldElement<T>>
100     RungeKuttaStateInterpolator convertInterpolator(final FieldODEStateInterpolator<T> fieldInterpolator,
101                                                     final FieldOrdinaryDifferentialEquation<T> eqn) {
102 
103         RungeKuttaFieldStateInterpolator<T> rkFieldInterpolator =
104                         (RungeKuttaFieldStateInterpolator<T>) fieldInterpolator;
105 
106         RungeKuttaStateInterpolator regularInterpolator = null;
107         try {
108 
109             String interpolatorName = rkFieldInterpolator.getClass().getName();
110             String integratorName = interpolatorName.replaceAll("Field", "");
111             @SuppressWarnings("unchecked")
112             Class<RungeKuttaStateInterpolator> clz = (Class<RungeKuttaStateInterpolator>) Class.forName(integratorName);
113 
114             java.lang.reflect.Field fYD = RungeKuttaFieldStateInterpolator.class.getDeclaredField("yDotK");
115             fYD.setAccessible(true);
116             @SuppressWarnings("unchecked")
117             final double[][] yDotK = convertArray((T[][]) fYD.get(rkFieldInterpolator));
118 
119             java.lang.reflect.Field fMapper = AbstractFieldODEStateInterpolator.class.getDeclaredField("mapper");
120             fMapper.setAccessible(true);
121             @SuppressWarnings("unchecked")
122             EquationsMapper regularMapper = convertMapper((FieldEquationsMapper<T>) fMapper.get(rkFieldInterpolator));
123 
124             java.lang.reflect.Constructor<RungeKuttaStateInterpolator> regularInterpolatorConstructor =
125                             clz.getDeclaredConstructor(Boolean.TYPE,
126                                                        double[][].class,
127                                                        ODEStateAndDerivative.class,
128                                                        ODEStateAndDerivative.class,
129                                                        ODEStateAndDerivative.class,
130                                                        ODEStateAndDerivative.class,
131                                                        EquationsMapper.class);
132             return regularInterpolatorConstructor.newInstance(rkFieldInterpolator.isForward(),
133                                                               yDotK,
134                                                               convertODEStateAndDerivative(rkFieldInterpolator.getGlobalPreviousState()),
135                                                               convertODEStateAndDerivative(rkFieldInterpolator.getGlobalCurrentState()),
136                                                               convertODEStateAndDerivative(rkFieldInterpolator.getPreviousState()),
137                                                               convertODEStateAndDerivative(rkFieldInterpolator.getCurrentState()),
138                                                               regularMapper);
139 
140         } catch (ClassNotFoundException | InstantiationException   | IllegalAccessException    |
141                  NoSuchFieldException   | IllegalArgumentException | InvocationTargetException |
142                  NoSuchMethodException  | SecurityException e) {
143             Assert.fail(e.getLocalizedMessage());
144         }
145 
146         return regularInterpolator;
147 
148     }
149 
150 }