1

我正在尝试计算可能是也可能不是紧密形式的多元函数的一堆一阶导数。为了为您提供更多上下文,我正在尝试计算选项的“希腊语”。期权价格/价值取决于很多因素:现货价格、行使价、波动率和利率等。最常用的希腊语之一称为delta,它是期权价格/价值相对于股票现货价格变化的一个单位的变化。期权的价格可能没有接近形式/分析形式,尽管为了简单起见,我在这里使用了一些接近形式。实际上,可以使用蒙特卡罗模拟计算价格。关键是,我需要一种“NumPy 友好”的方式来计算某些函数的这些一阶导数。这就是我相信很多机器学习/深度学习的人可以帮助我的地方。我参加了一些机器学习的入门课程,并且知道有一个自动微分、反向传播和其他东西的世界。我在这里使用的库是 JAX,它似乎与“numpy”有一些问题,因为错误消息如下所示:

 The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(14793.626)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(14793.626, dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.

请注意,我正在使用“定价器”,这是一个由其他人编写的定价函数,这个定价函数是用 numpy 编写的,无法使用其他库编写。工作量太大了。我必须“应用”他用 numpy 编写的定价函数。

顺便说一句,我修改了从某个论坛看到的代码。在原始代码中,使用的函数是一个五变量函数。我所做的只是简单地添加一个名为“divyield”的变量,它就是行不通!非常感谢!我感谢任何帮助或指示!

import jax.numpy as np
from jax.scipy.stats import norm
from jax import grad
import numpy as np
import scipy.stats as si
import sympy as sy
from sympy.stats import Normal, cdf
from sympy import init_printing
import jax.numpy as jnp
#import jnp  
init_printing()

class EuropeanCall:

    def __init__(self, inputs):
    
        self.spot_price = inputs[0]
        self.strike_price = inputs[1]
        self.time_to_expiration = inputs[2]
        self.risk_free_rate = inputs[3]
        self.divyield=inputs[4]
        self.volatility = inputs[5]
    
        self.price = EuropeanCall.black_scholes_call_div(self.spot_price, self.strike_price, self.time_to_expiration,
                                             self.risk_free_rate, self.divyield, self.volatility)

        self.gradient_func = grad(EuropeanCall.black_scholes_call_div, (0, 1, 3, 4))
        self.delta, self.vega, self.theta, self.rho = self.gradient_func(inputs[0], inputs[1], inputs[2], inputs[3], 
                                                                     inputs[4],inputs[5])
        self.theta /= -365
        self.vega /= 100
        self.rho /= 100



    @staticmethod
    def black_scholes_call_div(S, K, T, r, q, sigma):
    

#S: spot price
#K: strike price
#T: time to maturity
#r: interest rate
#q: rate of continuous dividend paying asset 
#sigma: volatility of underlying asset
#r=r+cds
        d1 = (np.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))
        d2 = (np.log(S / K) + (r - q - 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))

        call = (S * np.exp(-q * T) * si.norm.cdf(d1, 0.0, 1.0) - K * np.exp(-r * T) * si.norm.cdf(d2, 0.0, 1.0))

        return call

class EuropeanPut:

    def __init__(self, inputs):
    
        self.spot_price = inputs[0]
        self.strike_price = inputs[1]
        self.time_to_expiration = inputs[2]
        self.short_risk_free_rate = inputs[3]
        self.divyield=inputs[4]
        self.volatility = inputs[5]
    
        self.price = EuropeanPut.black_scholes_put_div(self.spot_price,  self.strike_price, self.time_to_expiration, 
                                            self.short_risk_free_rate,self.divyield,self.volatility)

        self.gradient_func = grad(EuropeanPut.black_scholes_put_div, (0,1,3,4))
        self.delta, self.vega, self.theta, self.rho = self.gradient_func(inputs[0], inputs[1], inputs[2], inputs[3], 
                                                                     inputs[4],inputs[5])
        self.theta /= -365
        self.vega /= 100
        self.rho /= 100



    @staticmethod
    def black_scholes_put_div(S, K, T, r, q, sigma):

#S: spot price
#K: strike price
#T: time to maturity
#r: interest rate
#q: rate of continuous dividend paying asset 
#sigma: volatility of underlying asset
#r=r+cds
        d1 = (np.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))
        d2 = (np.log(S / K) + (r - q - 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))

        put = (K * np.exp(-r * T) * si.norm.cdf(-d2, 0.0, 1.0) - S * np.exp(-q * T) * si.norm.cdf(-d1, 0.0, 1.0))

        return put

              #spot_price,vol, K,T,r
inputs = np.array([3109.62, .2102, 27/365,.017,0.02,0.25])
ec = EuropeanCall(inputs.astype('float'))
print(ec.delta, ec.vega, ec.theta, ec.rho)
4

1 回答 1

2

错误消息告诉您需要做什么:

您可能想检查您是否正在使用jnpwithimport jax.numpy as jnp而不是使用npviaimport numpy as np

JAX 不能区分numpy功能,但可以区分jax.numpy功能。因此,将 , , 等替换np.log为 , np.sqrt,np.expjnp.logjnp.sqrt同样jnp.expscipy调用替换为jax.scipy调用。通过 JAX 实现所有操作后,您应该能够使用 JAX 计算梯度。

如果您正在使用无法用 JAX 重写的 numpy 中实现的第三方模块,那么您将无法直接使用 JAX 转换,包括自动微分。

于 2020-11-22T18:46:45.953 回答