-2

在 `__enter__` 方法中,我想返回一个 Rust 和 Python 都可以访问的对象,这样 Rust 可以更新该对象中的值,而 Python 可以读取更新后的值。

我想要这样的东西:

#![feature(specialization)]

use std::thread;

use pyo3::prelude::*;
use pyo3::types::{PyType, PyAny, PyDict};
use pyo3::exceptions::ValueError;
use pyo3::PyContextProtocol;
use pyo3::wrap_pyfunction;

// Counters collected while scanning; exposed to Python through `#[pyclass]`.
#[pyclass]
#[derive(Debug, Clone)]
pub struct Statistics {
    // Number of files seen so far (incremented by `counter`).
    pub files: u32,
    // Error messages accumulated by `counter`.
    pub errors: Vec<String>,
}

fn counter(
    root_path: &str,
    statistics: &mut Statistics,
) {
    statistics.files += 1;
    statistics.errors.push(String::from("Foo"));
}

// Runs `counter` once and returns its results as a Python dict.
//
// The dict holds either a single "error" entry (when the worker reported an I/O error)
// or the "files" / "errors" entries copied from the collected statistics.
// Failures while filling the dict are propagated to Python as exceptions rather than
// panicking inside the extension module.
#[pyfunction]
pub fn count(
    py: Python,
    root_path: &str,
) -> PyResult<PyObject> {
    let mut statistics = Statistics {
        files: 0,
        errors: Vec::new(),
    };

    // The counting itself needs no Python objects, so it runs inside `allow_threads`.
    let rc: std::result::Result<(), std::io::Error> = py.allow_threads(|| {
        counter(root_path, &mut statistics);
        Ok(())
    });

    let pyresult = PyDict::new(py);
    if let Err(e) = rc {
        pyresult.set_item("error", e.to_string())?;
        return Ok(pyresult.into());
    }
    pyresult.set_item("files", statistics.files)?;
    pyresult.set_item("errors", statistics.errors)?;
    Ok(pyresult.into())
}

// Python-visible context manager intended to run `counter` on a background thread.
#[pyclass]
#[derive(Debug)]
pub struct Count {
    // Path handed to `counter` when the worker is started.
    root_path: String,
    // Set to `true` once `__exit__` has run.
    exit_called: bool,
    // Handle of the worker thread; `None` until `__enter__` spawns it.
    thr: Option<thread::JoinHandle<()>>,
    // Statistics the worker is supposed to update and Python is supposed to read.
    statistics: Statistics,
}

#[pymethods]
impl Count {
    // Python constructor: `Count(root_path)`. Starts with empty statistics and no worker.
    #[new]
    fn __new__(
        obj: &PyRawObject,
        root_path: &str,
    ) {
        let statistics = Statistics {
            files: 0,
            errors: Vec::new(),
        };
        let instance = Count {
            root_path: root_path.to_owned(),
            exit_called: false,
            thr: None,
            statistics,
        };
        obj.init(instance);
    }

    // Python property `statistics`: hands back an independent copy of the current values.
    #[getter]
    fn statistics(&self) -> PyResult<Statistics> {
        Ok(self.statistics.clone())
    }
}

// Context-manager protocol (`with Count(...) as c:`) for `Count`.
#[pyproto]
impl<'p> PyContextProtocol<'p> for Count {
    // Called on `with` entry: tries to start `counter` on a background thread and
    // return the object itself to Python.
    fn __enter__(&mut self) -> PyResult<Py<Count>> {
        let gil = GILGuard::acquire();
        // This is the spawn rejected with E0477: the closure borrows `self`
        // (`self.root_path`, `self.statistics`), but `thread::spawn` requires the
        // closure to satisfy the `'static` lifetime.
        self.thr = Some(thread::spawn(|| {
            counter(self.root_path.as_ref(), &mut self.statistics)
        }));
        // NOTE(review): `*self` moves the value out from behind `&mut self`; this looks
        // like it cannot compile either -- confirm once the spawn error is resolved.
        Ok(PyRefMut::new(gil.python(), *self).unwrap().into())
    }

    // Called on `with` exit: waits for the worker, records that exit ran, and
    // returns `true` only when the pending exception type is `ValueError`.
    fn __exit__(
        &mut self,
        ty: Option<&'p PyType>,
        _value: Option<&'p PyAny>,
        _traceback: Option<&'p PyAny>,
    ) -> PyResult<bool> {
        // NOTE(review): `unwrap()` consumes the `Option`, which is not possible through
        // `&mut self` without `take()`; it would also panic if `__enter__` never ran.
        self.thr.unwrap().join();
        let gil = GILGuard::acquire();
        self.exit_called = true;
        if ty == Some(gil.python().get_type::<ValueError>()) {
            Ok(true)
        } else {
            Ok(false)
        }
    }
}

// Module initialiser for the Python module `count`: registers the `Count` class
// and the `count` function.
#[pymodule(count)]
fn init(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Count>()?;
    // The last registration already yields `PyResult<()>`, so it is the return value.
    m.add_wrapped(wrap_pyfunction!(count))
}

但我收到以下错误:

error[E0477]: the type `[closure@src/lib.rs:90:39: 92:10 self:&mut &'p mut Count]` does not fulfill the required lifetime
  --> src/lib.rs:90:25
   |
90 |         self.thr = Some(thread::spawn(|| {
   |                         ^^^^^^^^^^^^^
   |
   = note: type must satisfy the static lifetime
4

1 回答 1

1

我找到了解决方案。使用受保护的共享引用(`Arc<Mutex<Statistics>>`)可以解决这个问题:

#![feature(specialization)]

use std::{thread, time};
use std::sync::{Arc, Mutex};

extern crate crossbeam_channel as channel;
use channel::{Sender, Receiver, TryRecvError};

use pyo3::prelude::*;
use pyo3::types::{PyType, PyAny};
use pyo3::exceptions::ValueError;
use pyo3::PyContextProtocol;

// Counters updated by the worker thread; shared behind `Arc<Mutex<..>>` by `Count`.
#[pyclass]
#[derive(Debug, Clone)]
pub struct Statistics {
    // Incremented once per loop iteration of `counter`.
    pub files: u32,
    // Error messages; `counter` appends one entry when it finishes.
    pub errors: Vec<String>,
}

/// Demo worker: ticks up to 14 times, 100 ms apart, bumping `files` on every tick.
///
/// After each tick the `cancel` channel is polled without blocking; a received message
/// or a disconnected channel prints "Terminating." and ends the loop early. In every
/// case a single "Foo" entry is appended to `errors` before returning.
///
/// # Panics
///
/// Panics if the statistics mutex is poisoned.
pub fn counter(
    statistics: Arc<Mutex<Statistics>>,
    cancel: &Receiver<()>,
) {
    for _ in 1..15 {
        thread::sleep(time::Duration::from_millis(100));
        // The guard is a temporary, so the lock is released at the end of this statement.
        statistics.lock().unwrap().files += 1;

        if let Err(TryRecvError::Empty) = cancel.try_recv() {
            // Nothing received yet: keep counting.
            continue;
        }
        println!("Terminating.");
        break;
    }
    statistics.lock().unwrap().errors.push(String::from("Foo"));
}

// Python-visible context manager that runs `counter` on a background thread.
#[pyclass]
#[derive(Debug)]
pub struct Count {
    // Set to `true` once `__exit__` has run.
    exit_called: bool,
    // Statistics shared between this object and the worker thread.
    statistics: Arc<Mutex<Statistics>>,
    // Handle of the worker thread; `None` until `__enter__` spawns it.
    thr: Option<thread::JoinHandle<()>>,
    // Sending half of the cancellation channel; `None` until `__enter__` creates it.
    cancel: Option<Sender<()>>,
}

#[pymethods]
impl Count {
    #[new]
    fn __new__(obj: &PyRawObject) {
        obj.init(Count {
            exit_called: false,
            statistics: Arc::new(Mutex::new(Statistics {
                files: 0,
                errors: Vec::new(),
            })),
            thr: None,
            cancel: None,
        });
    }

    #[getter]
    fn statistics(&self) -> PyResult<u32> {
        let s = Arc::clone(&self.statistics).lock().unwrap().files;
        Ok(s)
     }
}

// Context-manager protocol (`with c:`) for `Count`.
#[pyproto]
impl<'p> PyContextProtocol<'p> for Count {
    // Called on `with` entry: creates the cancellation channel and starts `counter`
    // on a background thread. Only a clone of the `Arc` and the receiver move into
    // the thread, so the closure owns everything it uses.
    fn __enter__(&'p mut self) -> PyResult<()> {
        let statistics = Arc::clone(&self.statistics);
        let (sender, receiver) = channel::bounded(1);
        self.cancel = Some(sender);
        self.thr = Some(thread::spawn(move || counter(statistics, &receiver)));
        Ok(())
    }

    // Called on `with` exit: asks the worker to stop, waits for it, and records that
    // exit ran. Returns `true` only when the pending exception type is `ValueError`.
    //
    // Safe to call without a preceding `__enter__`: with no sender and no thread
    // handle stored, both steps are skipped instead of panicking.
    fn __exit__(
        &mut self,
        ty: Option<&'p PyType>,
        _value: Option<&'p PyAny>,
        _traceback: Option<&'p PyAny>,
    ) -> PyResult<bool> {
        if let Some(cancel) = self.cancel.take() {
            // A send error means the worker already finished; nothing left to cancel.
            let _ = cancel.send(());
        }
        if let Some(worker) = self.thr.take() {
            // The join result (worker panic) is ignored, as in the original code.
            let _ = worker.join();
        }
        let gil = GILGuard::acquire();
        self.exit_called = true;
        Ok(ty == Some(gil.python().get_type::<ValueError>()))
    }
}

// Object protocol for `Count`: provides Python's `str(obj)`.
#[pyproto]
impl pyo3::class::PyObjectProtocol for Count {
    // Renders the struct through its derived `Debug` implementation.
    fn __str__(&self) -> PyResult<String> {
        let rendered = format!("{:?}", self);
        Ok(rendered)
    }
}

// Module initialiser for the Python module `count`: registers the `Count` class.
#[pymodule(count)]
fn init(_py: Python, m: &PyModule) -> PyResult<()> {
    // The registration already yields `PyResult<()>`, so it is the return value.
    m.add_class::<Count>()
}

现在我可以运行以下代码:

import time

import count

# Create the Rust-backed counter object exposed by the `count` extension module.
c = count.Count()

# Entering the block calls `Count.__enter__` (starts the worker thread);
# leaving it calls `Count.__exit__` (cancels and joins the worker).
with c:
    for _ in range(5):
        # `statistics` is a property backed by the Rust getter; it returns the
        # current `files` counter, which the worker updates in the background.
        print(c.statistics)
        time.sleep(0.1)

如示例所示,线程取消也能正常工作,不过使用 `thread_control` crate 可能是更好的解决方案。

于 2020-01-04T06:01:17.067 回答