当前位置: 首页 > news >正文

Jax(Random、Numpy)常用函数

目录

Jax

 vmap

Array

reshape

Random

PRNGKey

uniform

normal

split

 choice

Numpy

expand_dims

linspace

jax.numpy.linalg[pkg]

dot

matmul

arange

interp 

tile

reshape


Jax

jit

jax.jit(funin_shardings=UnspecifiedValueout_shardings=UnspecifiedValuestatic_argnums=Nonestatic_argnames=Nonedonate_argnums=Nonedonate_argnames=Nonekeep_unused=Falsedevice=Nonebackend=Noneinline=Falseabstracted_axes=None)[source]

注:jax.jit 是 JAX 中的一个装饰器,用于将 Python 函数编译为高效的机器代码,以提高运行速度。JIT(Just-In-Time)编译可以加速函数的执行,尤其是在循环或需要多次调用。

>>>jax.jit(lambda x,y : x + y)
<PjitFunction of <function <lambda> at 0x7ea7b402f130>>
>>>jax.jit(lambda x,y : x + y)(1,2) #process jitfunc -> lambda fun
Array(3, dtype=int32, weak_type=True)
>>>@jax.jitdef fun(x,y):return x + y
>>>fun
<PjitFunction of <function fun at 0x7ea7b402f5b0>>
>>>fun(1,2)
Array(3, dtype=int32, weak_type=True)

 vmap

jax.vmap(funin_axes=0out_axes=0axis_name=Noneaxis_size=Nonespmd_axis_name=None)[source]

注:对函数进行向量化处理,通常用于批量处理数据,而不需要显式地编写循环,函数映射调用,区别于pmap,vmap单个设备(CPU或GPU)上处理批量数据,pmap在多个设备(GPU或TPU)上并行处理数据(分布式)

>>>f_xy = lambda x,y : x + y
>>>x = jax.numpy.array([[1, 2], [3, 4]])  # shape (2, 2)
>>>y = jax.numpy.array([[5, 6], [7, 8]])  # shape (2, 2)# in this x and y array, axis 0 is row , axis 1 is col, ref shape index
# in x and y, axis -1 is shape[-1] , axis -2 is shape[-2]>>>jax.vmap(f_xy,in_axes=(0,0))(x,y)      # default out_axes = 0,row ouput
# x row + y row , need x row dim equal y row dim
Array([[ 6,  8],[10, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,0),out_axes=1)(x,y) #show output by col
Array([[ 6,  8],[10, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1))(x,y) 
# x row + y col , need x row's dim equal y col's dim
Array([[ 6,  9],[ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1),out_axes=1)(x,y) #show output by col 
Array([[ 6,  9],[ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(None,0))(x,y) #no vector x by row or col, x is block
# x block + y row vector, x shape (2,2) , y shape(2,2), need x row equal y row
# return shape(y_dim_2,x_dim_1,x_dim2)
Array([[[ 6,  8],[ 8, 10]],[[ 8, 10],[10, 12]]], dtype=int32)

ref:Learning about JAX :axes in vmap()

Array

reshape

abstract Array.reshape(*argsorder='C')[source]

注:Array对象的实例方法,引用jax.numpy.reshape函数

Random

PRNGKey

jax.random.PRNGKey(seed*impl=None)[source]#

注:创建一个 PRNG key,作为生成随机数的种子Seed

eg:       

>>>jax.random.PRNGKey(0)
Array([0, 0], dtype=uint32)

uniform

jax.random.uniform(keyshape=()dtype=<class 'float'>minval=0.0maxval=1.0)[source]

注:在给定的形状(shape)和数据类型(dtype)下,从 [minval, maxval) 区间内采样均匀分布的随机值

>>>k = jax.random.PRNGKey(0)
>>>jax.random.uniform(k,shape=(1,))
Array([0.41845703], dtype=float32)

normal

normal(keyshape=()dtype=<class 'float'>)[source]

注:在给定的形状shape和浮点数据类型dtype下,采样标准正态分布的随机值

>>>k = jax.random.PRNGKey(0)
>>>jax.random.normal(k,shape=(1,))
Array([-0.20584226], dtype=float32)

split

jax.random.split(keynum=2)[source]

注:用于生成伪随机数生成器(PRNG)状态的函数。它允许你从一个现有的 PRNG 状态中生成多个新的状态,从而实现随机数的可重复性和并行性。 

>>>k = jax.random.PRNGKey(1)
>>>k1,k2 = jax.random.split(k)
>>>k1
Array([2441914641, 1384938218], dtype=uint32)
>>>k2
Array([3819641963, 2025898573], dtype=uint32)

 choice

jax.random.choice(keyashape=()replace=Truep=Noneaxis=0)[source]

注:从给定数组a中按shape生成随机样本,区别于numpy.random.choice函数。default choice one elem。

>>>k = jax.random.PRNGKey(0)
>>>a = jax.numpy.array([1,2,3,4,5,6,7,8,9,0])
>>>jax.random.choice(k,a,(10,)) # random no seq
Array([9, 6, 8, 7, 8, 4, 1, 2, 3, 3], dtype=int32)
>>>jax.random.choice(k,a,(2,5))
Array([[9, 6, 8, 7, 8],[4, 1, 2, 3, 3]], dtype=int32)

Numpy

expand_dims

expand_dims(aaxis)[source]

注:为数组a的维度axis增加1维度

>>>arr = jax.numpy.array([1,2,3])
>>>arr.shape
(3,)
>>>jax.numpy.expand_dims(arr,axis=0)
Array([[1, 2, 3]], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=0).shape
(1, 3)
>>>jax.numpy.expand_dims(arr,axis=1)
Array([[1],[2],[3]], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=1).shape
(3, 1)

linspace

linspace(start: ArrayLikestop: ArrayLikenum: int = 50endpoint: bool = Trueretstep: Literal[False] = Falsedtype: DTypeLike | None = Noneaxis: int = 0*device: xc.Device | Sharding | None = None) → Array[source]

注:在给定区间[start,stop]内返回均匀间隔的数字

>>>jax.numpy.linspace(0,1,5)
Array([0.  , 0.25, 0.5 , 0.75, 1.  ], dtype=float32)

jax.numpy.linalg[pkg]

jax.numpy.linalg 是 JAX 库中用于线性代数操作的模块,对应numpy.linalg库实现

cholesky

        jax.numpy.linalg.cholesky(a*upper=False)[source]

        注:计算一个正定矩阵A的 Cholesky 分解,得到满足A=L@L.T等式的下三角或上三角矩阵L,@为Python1.5定义的矩阵乘运算(jax.numpy.matmul),L.T为L转置矩阵L^{T}

>>> d = jax.numpy.array([[2. , 1.],[1. , 2.]])
>>>jax.numpy.linalg.cholesky(d)
Array([[1.4142135 , 0.        ],[0.70710677, 1.2247449 ]], dtype=float32)>>>L = jax.numpy.linalg.cholesky(d)
>>>L@L.T
Array([[1.9999999 , 0.99999994],[0.99999994, 2.        ]], dtype=float32)
eigvalsh

jax.numpy.linalg.eigvalsh(aUPLO='L')[source]

注:计算 Hermitian 对称矩阵的特征值。对于一个给定的方阵 A,其特征值 λ 和特征向量 v满足以下关系Av=λv。cholesky分解矩阵需满足特征值>0。

>>>jax.numpy.linalg.eigvalsh(jax.numpy.array([[1,-1],[-1,1]]))
Array([0., 2.], dtype=float32)
 cond

jax.numpy.linalg.cond(xp=None)[source]

注:用于计算矩阵的条件数(condition number),这是衡量矩阵在数值计算中稳定性的重要指标。高条件数警示需要谨慎对待矩阵的计算,尤其是在求解线性方程或进行其他数值计算时,如cholesky分解。

>>>jax.numpy.linalg.cond(jax.numpy.array([[1,2],[2,1]]))
Array(3., dtype=float32)

allclose

jax.numpy.allclose(abrtol=1e-05atol=1e-08equal_nan=False)[source]

注:检查两个数组的元素是否在容差范围内近似相等,cholesky分解矩阵需满足对称性。

>>>A=jax.numpy.array([[4, 2],[2, 3]])
>>>jax.numpy.allclose(A,A.T)
Array(True, dtype=bool)
# A 为对称矩阵

dot

dot(ab*precision=Nonepreferred_element_type=None)[source]

注:用于计算两个数组的点积(dot product),对于一维数组,它计算的是向量的内积;对于二维数组(矩阵),它计算的是矩阵乘积;对于更高维度的数组,它执行的是逐元素的点积,并在最后一个轴上进行求和

  • 对于一维数组(向量)numpy.dot(a, b) 计算的是向量 a 和 b 的点积,结果是一个标量。
  • 对于二维数组(矩阵)numpy.dot(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相等。结果是一个新的矩阵。
  • 对于更高维度的数组numpy.dot() 可以进行更复杂的广播和求和运算,但通常用于计算张量积(tensor product)的某个维度上的和。
>>>jax.numpy.dot(jax.numpy.array([1,2,3]),2)
Array([2, 4, 6], dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
Array(14, dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([[1,2,3],[4,5,6]]),jax.numpy.array([1,2,3]))
Array([14, 32], dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([[1,2],[4,5]]),jax.numpy.array([[1,2],[4,5]]))
Array([[ 9, 12],[24, 33]], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.dot(a,b).shape
(1, 3, 1, 4) #matmul ret (1,3,4)

matmul

matmul(ab*precision=Nonepreferred_element_type=None)[source]#

注:于执行矩阵乘法,也称为 @ 运算符(在 Python 3.5+ 中引入),对于一维数组(向量),它计算的是内积(与 dot 相同);对于二维数组(矩阵),它计算的是矩阵乘积(与 dot 相同);对于更高维度的数组,它执行的是逐元素的矩阵乘法,并保留其他轴

  • 对于一维数组(向量)numpy.matmul(a, b) 通常不被定义为向量之间的运算,除非 a 是一个二维数组(表示多个向量)的单个行或列,并且 b 的形状与之兼容。
  • 对于二维数组(矩阵)numpy.matmul(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相等。这与 numpy.dot() 对于二维数组的行为相同。
  • 对于更高维度的数组numpy.matmul() 遵循爱因斯坦求和约定(Einstein summation convention)的特定规则,允许在不同维度的数组之间执行矩阵乘法。这包括批处理矩阵乘法,其中每个批次独立地进行乘法运算。
>>>jax.numpy.matmul(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
Array(14, dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([[1,2,3],[4,5,6]]),jax.numpy.array([1,2,3]))
Array([14, 32], dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([[1,2],[4,5]]),jax.numpy.array([[1,2],[4,5]]))
Array([[ 9, 12],[24, 33]], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.matmul(a,b).shape
(1, 3, 4) #dot ret (1,3,1,4)

arange

jax.numpy.arange(startstop=Nonestep=Nonedtype=None*device=None)[source]

注:default step 为1,在区间[start,stop)生成步长为1的数组,类似range函数

>>>jax.numpy.arange(0,10,1)
Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

interp 

interp(xxpfpleft=Noneright=Noneperiod=None)[source]

注:在xp点列表中线性插值x,线性插值满足y=y_{i}+\frac{y_{i+1}-y_{i}}{x_{i+1}-x_{i}}(x-x_{i}),x\epsilon [x_{i},x_{i+1}),xi和xi+1表示xp数组相邻两点,插值x位于两点区间之间,xp点对于y值为fp,线性插值为保持符合fp = fun(xp)两点区间斜率的增量

>>>xp = jax.numpy.arange(0,10,1)
>>>fp = jax.numpy.array(range(0,10,1)) * 2
>>>x = jax.numpy.array([1,2,3])
>>>jax.numpy.interp(x,xp,fp)
Array([2., 4., 6.], dtype=float32)

tile

jax.numpy.tile(Areps)[source]

注:将A数组按reps重复化生成新Array

a = jax.numpy.array([1,2,3])
>>>jax.numpy.tile(a,2)
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>>jax.numpy.tile(a,(2,))
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>>jax.numpy.tile(a,(1,1))
Array([[1, 2, 3]], dtype=int32)
>>>jax.numpy.tile(a,(2,1)) # repeat axis 0 (row) by 2, repeat axis 1 (col) by 1
Array([[1, 2, 3],[1, 2, 3]], dtype=int32)

reshape

jax.numpy.reshape(ashape=Noneorder='C'*newshape=Deprecatedcopy=None)[source]

注:从定义Array a的shape形状为shape元组(),支持-1,推断dim数值

>>>a = jax.numpy.array([[1, 2, 3],[4, 5, 6]])
>>>jax.numpy.reshape(a,6) # equal reshape(a,(6,))
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>>jax.numpy.reshape(a,-1) # equal reshape(a,6)  -1 is inferred to be 3
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>>jax.numpy.reshape(a,(-1,2)) # equal reshape(a,(3,2)) , -1 is inferred to be 3
Array([[1, 2],[3, 4],[5, 6]], dtype=int32)
>>>jax.numpy.reshape(a,(1,-1)) # not (n,) inferred to 2 d
Array([[1, 2, 3, 4, 5, 6]], dtype=int32)

meshgrid

jax.numpy.meshgrid(*xicopy=Truesparse=Falseindexing='xy')[source]

注:创建坐标矩阵,将一维坐标向量xi(自变量x、y)转换为对应的二维坐标向量或矩阵,适用于计算网格点上的函数值(因变量z),默认indexing='xy'输出笛卡尔坐标(row为vector),indexing='ij'输出矩阵坐标(col为vector)

>>>x = jax.numpy.array([1,2,3])
>>>y = jax.numpy.array([4,5])
>>>jax.numpy.meshgrid(x,y) #default indexing='xy'
[Array([[1, 2, 3],[1, 2, 3]], dtype=int32),Array([[4, 4, 4],[5, 5, 5]], dtype=int32)]
>>>jax.numpy.meshgrid(x,y,indexing='ij')
[Array([[1, 1],[2, 2],[3, 3]], dtype=int32),Array([[4, 5],[4, 5],[4, 5]], dtype=int32)]
>>>xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')
>>>xv
Array([[1, 2, 3],[1, 2, 3]], dtype=int32)
>>>yv
Array([[4, 4, 4],[5, 5, 5]], dtype=int32)
>>>xv.ravel()
Array([1, 2, 3, 1, 2, 3], dtype=int32) 
>>>yv.ravel()
Array([4, 4, 4, 5, 5, 5], dtype=int32)#Array.ravel return a view of array (no memory),  flatten return a copy of array

 自变量x shape(3,) 自变量y shape(2,),对应平面6个点, 对应值因变量z shape为(6,) 6个数值
,二维坐标可视化代码:

import jax
import matplotlib.pyplot as pltx = jax.numpy.array([1,2,3])
y = jax.numpy.array([4,5])xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')z = xv + yvplt.scatter(xv.flatten(), yv.flatten(), c=z, cmap='viridis') #use xv , yv also show similar graph
plt.colorbar(label='u')
plt.xlim(0, 4)
plt.ylim(3, 6)
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.title('Grid Units Visualization')
plt.show()

尝试将点变多:

import jax
import matplotlib.pyplot as pltx = jax.numpy.linspace(0,10,100)
y = jax.numpy.linspace(0,10,100)xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')z = xv + yvplt.scatter(xv.flatten(), yv.flatten(), c=z, cmap='viridis')
plt.colorbar(label='z')
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.title('Grid Units Visualization')
plt.show()

 eye

jax.numpy.eye(NM=Nonek=0dtype=None*device=None)[source]

注:用于创建单位矩阵的函数。单位矩阵是一种方阵,其主对角线上的元素为 1,其余元素为 0。

>>>jax.numpy.eye(3)
Array([[1., 0., 0.],[0., 1., 0.],[0., 0., 1.]], dtype=float32)

experimental

ode

jax.experimental.ode.odeint(func,y0,t,*args, rtol=1.4e-08, atol=1.4e-08, mxstep=inf, hmax=inf)

注:odeint基于 LSODA 算法的数值积分方法,求解微分方程(原函数),面向scipy.integrate模块的常微分函数。scipy.integrate.odeint[source]

scipy.integrate._odepack_py.odeint(func, y0, t, args=(), Dfun=None, col_deriv=0, full_output=0, ml=None, mu=None, rtol=None, atol=None, tcrit=None, h0=0.0, hmax=0.0, hmin=0.0, ixpr=0, mxstep=0, mxhnil=0, mxordn=12, mxords=5, printmessg=0, tfirst=False)

The scipy.integrate module contains functions for numerical integration. The more sophisticated of these (quaddblquadode) are out-of-scope for JAX by axes 1 & 4, since they tend to be loopy algorithms based on dynamic numbers of evaluations. jax.experimental.ode.odeint() is related, but rather limited and not under any active development.

scipy.integrate 模块包含数值积分函数。其中更复杂的(quad、dblquad、ode)超出了轴 1 和 4 的 JAX 范围,因为它们往往是基于动态评估数量的循环算法。 jax.experimental.ode.odeint() 是相关的,但相当有限,并且没有任何积极的开发。

ref:https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html

常微分指只有一个自变量和因变量的导数关系,常微分方程(ODE):通常描述一个或多个变量(通常是时间)对一个或多个未知函数的导数关系。偏微分方程(PDE): 描述二个或多个自变量对一个或多个未知函数的偏导数关系。阶数指原函数对自变量求导次数或微分次数。

1)简单一阶常微分方程推导原函数案例

导数为函数本身的微分方程:

\frac{dy}{dx}=y

分离变量:

\frac{1}{y}dy=dx

积分:

\int \frac{1}{y}dy=\int dx=\int 1dx

得;

ln^{\left | y \right |}=x+C

求解y(x):

y=Ce^{x}

设初始条件y(x)=0,那么有C=1,则:

y=e^{x}

注:e^{x}的导数为e^{x}

>>>import jax
>>>from jax.experimental.ode import odeint
>>>y_x_fun = lambda y,x : y  # dy/dx = y
>>>y0 = 1.0
>>>x = jax.numpy.linspace(0,1,10)
>>>sovle_u = odeint(u_x_fun,u0,x)
>>>sovle_u
Array([1.       , 1.1175194, 1.2488489, 1.3956124, 1.5596234, 1.7429088,1.9477342, 2.176628 , 2.4324257, 2.7182808], dtype=float32)
>>>y_fun = lambda x : jax.numpy.exp(x) # when y(0) = 1.0 , C = 1.0
>>>y = y_fun(x)
>>>y
Array([1.       , 1.117519 , 1.2488489, 1.3956125, 1.5596235, 1.7429091,1.9477341, 2.17663  , 2.4324255, 2.7182817], dtype=float32)
>>>jax.numpy.allclose(sovle_y,y) #在一定误差范围内判断数值相等
Array(True, dtype=bool)

可视化odeint求的原函数结果,和本身原函数值:

import matplotlib.pyplot as plt
# solve by odeint
plt.subplot(1, 2, 1)
plt.plot(x, sovle_y, label='odeint', color='blue')
plt.title('Solve by odeint')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()# compute by ux function
plt.subplot(1, 2, 2)
plt.plot(x, y, label='y', color='red')
plt.title('Compute by y(x) function ')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.tight_layout()
plt.show()

 导数、微分与积分知识点:

\frac{dY(x))}{dx}=y(x)

那么有\int dY(x))=\int 1dY(x)=\int y(x)dx

设置x边界[0,1],那么有\int_{0}^{x}1dY(t) =Y(x)-Y(0)= \int_{0}^{x}y(t)dt ,x\epsilon [0,1]

得原函数:Y(x)=\int_{0}^{x}y(t)dt+Y(0)=\int_{0}^{x}y(t)dt+C,x\epsilon [0,1]

设定初始值:Y(0)=0,那么C=0,则有:Y(x)=\int_{0}^{x}y(t)dt,x\epsilon [0,1]

注:Y(x)为原函数,y(x)为导函数,根据导函数求原函数过程为反导(Antiderivative),上述简单一阶微分方程为原函数的导函数是原函数。

2)简单二阶常微分方程推导原函数案例

已知二阶微分方程:\frac{d^{2}y}{dt^{2}}=1

设一阶方程\frac{dy}{dt}=v,那么有\frac{d(\frac{dy}{dt})}{dt}=\frac{dv}{dt}=1

积分得:v(t)=t+C_{1},其中C_{1}为积分常数

因为:\frac{dy}{dt}=v,所以\frac{dy}{dt}=t+C_{1},左右积分得:y(t)=\frac{1}{2}t^{2}+C_{1}t+C_{2},其中C_{2}为第二次积分常数

注:y(t)为原函数,v(t)为原函数关于t得一阶导函数dy(t)/dt

使用两次odeint求解该二阶微分方程并将原函数y(t)和一阶导函数v(t)可视化:

from scipy.integrate import odeint
import numpy# v关于t得一阶导函数(本质是y关于t得二阶导函数),第一次 odeint:求解v(t) 从 dv/dt = 1
def v_t_fun(v, t): return 1.  # dv/dt = 1 , jax odeint need float, scipy odeint promise int# y关于t得一阶导函数,第二次 odeint:求解 y(t) 从 dy/dt = v(t)
def y_t_fun(y, t, V, T): # there t just one value not a vector, V is solve vector in T scope#find t in solve_v's t index in T scopei = numpy.searchsorted(T,t)return V[i-1] # dy/dt = v(t), 这里 v(t) 来自第一次 odeint 的结果# 自变量点
>>>t = jax.numpy.linspace(0, 1, 10) # this t is T scope vector
>>>t
Array([0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],      dtype=float32)
# 第一次求解 v(t)
>>>v0 = 0.  # 初始 v(0) = 0
>>>solve_v = odeint(v_t_fun, v0, t) #scipy odeint return 2d values,不影响画图,也可以展平1d
>>>solve_v
array([[0.        ],[0.11111111],[0.22222222],[0.33333333],[0.44444444],[0.55555556],[0.66666667],[0.77777778],[0.88888889],[1.        ]])
# 第二次求解位置 y(t)
>>>y0 = 0  # 初始位置 y(0) = 0
>>>solve_y = odeint(y_t_fun, y0, t, args=(solve_v,t)) #  scipy odeint args传参args=(solve_v,t)
>>>solve_y
array([[0.00000000e+00],[3.26174305e-09],[1.23456528e-02],[3.70369954e-02],[7.40740272e-02],[1.23456756e-01],[1.85185140e-01],[2.59259219e-01],[3.45678981e-01],[4.44444369e-01]])import matplotlib.pyplot as plt
# 绘制结果
plt.plot(t, solve_y, label='Solve y(t)')
plt.plot(t, solve_v, label="Solve v(t)")
plt.xlabel('Independent variable t')
plt.ylabel('Solution dependent variable')
plt.legend()
plt.grid()
plt.title('Solution using two odeint calls')
plt.show()

注:jax odeint和scipy odeint差别不小,jax不支持多参数微分函数传递(尝试多种方式失效,无奈),使用两次odeint需要v(t)函数,odeint求解得v值需要关于t映射到 y函数得一阶导函数dy(t)/dt=v(t)(对应y_t_fun)

使用一次odeint求解二阶常微分方程\frac{d^{2}y}{dt^{2}}=1,一次输出v(t),y(t):

import jax
from jax.experimental.ode import odeintdef yv_t_fun(Y,t):y,v = Ydydt = vdvdt = 1return jax.numpy.array([dydt,dvdt])>>>t = jax.numpy.linspace(0,1,10)
>>>Y0 = jax.numpy.array([0., 0.])
>>>solves = odeint(yv_t_fun,Y0,t)
>>>solves
Array([[0.        , 0.        ],[0.00617284, 0.11111113],[0.02469137, 0.22222222],[0.05555557, 0.33333337],[0.09876544, 0.44444445],[0.15432101, 0.5555556 ],[0.22222228, 0.66666675],[0.30246922, 0.7777778 ],[0.39506182, 0.8888889 ],[0.5000001 , 1.        ]], dtype=float32)
>>>solve_v,solve_y = solves[:,1],solves[:,0]
>>>solve_v
Array([0.        , 0.11111113, 0.22222222, 0.33333337, 0.44444445,0.5555556 , 0.66666675, 0.7777778 , 0.8888889 , 1.        ],      dtype=float32)
>>>solve_y
Array([0.        , 0.00617284, 0.02469137, 0.05555557, 0.09876544,0.15432101, 0.22222228, 0.30246922, 0.39506182, 0.5000001 ],      dtype=float32)#可视化
import matplotlib.pyplot as plt
# 绘制结果
plt.plot(t, solve_y, label='Solve y(t)')
plt.plot(t, solve_v, label="Solve v(t)")
plt.xlabel('Independent variable t')
plt.ylabel('Solution dependent variable')
plt.legend()
plt.grid()
plt.title('Solution using one odeint calls')
plt.show()


http://www.mrgr.cn/news/47211.html

相关文章:

  • 求图的各结点的入度个数
  • unity 调整skinweight (皮肤权重),解决:衣服穿模问题
  • vector(2)
  • 手写Spring第三篇番外,反射的基本使用
  • springboot民宿酒店客房管理系统-计算机毕业设计源码46755
  • Ascend C算子编程和C++基础 Lesson1-1 从人工智能到算子
  • 本地部署私人知识库的大模型!Llama 3 + RAG +大模型开源教程「动手学大模型应用开发」!
  • 实景三维赋能地下管线综合智管应用
  • 蓝牙5.4技术解析:更快、更稳定的无线通信
  • RandLA-Net PB C++
  • 项目经理是怎么慢慢废掉的?这些无意识行为可能会毁了你!
  • JS激活已有标签页(页面存在则激活,关闭则打开)
  • el-tree 修改每个层级的背景色
  • 平板外壳高精度标签粘贴应用
  • Redis SpringBoot项目学习
  • 二叉树系列(遍历/dfs/bfs)10.10
  • Linux 常用命令详细总结
  • Android -- [SelfView] 自定义多色渐变背景板
  • Java对请求参数进行校验
  • [C#]未能加载文件或程序集Newtonsoft.Json