0

我想提供一个应用程序,该应用程序使用烧瓶和 gunicorn 在 googles JAX 框架中处理数据。

如果在烧瓶内运行,一切正常。一旦我在 gunicorn 中运行应用程序,每个与 jax 相关的部分都会导致工作进程死亡,而不会引发任何异常。我尝试同时使用同步和 gthreads 作为工作线程,但结果相同。

我试图通过在 ThreadPoolExecutor 和 ProcessPoolExecutor 中包装相同的调用来查看 JAX 是否可以处理多处理和多线程,并且可以完美地工作。

import jax

import logging
logging.basicConfig(format="%(asctime)s | %(name)12.12s | %(message)s")
logger = logging.getLogger("Main")
logger.setLevel(logging.DEBUG)

from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed

from fit.optimization.vectorize import BatchNumpyInterface, batch_calculate_fit

def warmup():
    logger.debug("Warmup")
    data = BatchNumpyInterface.generate_dummy()
    batch_calculate_fit(data)
    logger.debug("Warmed up")

def run_fn():
    logger.debug("Creating data")
    data = BatchNumpyInterface.generate_dummy(100)
    
    logger.debug("Predicting %s in batches", 100)
    result = batch_calculate_fit(data)

    logger.debug("Done")
    return float(result[0][0]), float(result[1][0])

#with ThreadPoolExecutor(max_workers=4) as executor:
with ProcessPoolExecutor(max_workers=4) as executor:
    results = []
    for i in range(4):
        results.append(executor.submit(warmup))

    for res in as_completed(results):
        continue

    results = []
    for i in range(10):
        future = executor.submit(run_fn)
        results.append(future)

    for res in as_completed(results):
        print(res.result())


在调试期间,每次我检查 JAX DeviceArray 时,应用程序都会崩溃。使用 JAX 跳过第一个计算也是如此。

任何帮助将非常感激!

4

0 回答 0