0

我用下面这段代码从 MySQL 数据库中获取信息:

def query_result_connect(_query):
    """Run ``_query`` against MySQL through an SSH tunnel and return its result.

    For statements that produce a result set (``cursor.description`` is not
    None, e.g. SELECT) a pandas DataFrame is returned; for other statements
    the value of ``cursor.fetchall()`` is returned (normally empty). The
    statement is committed either way.

    Relies on module-level settings (``ssh_host``, ``ssh_port``, ``ssh_user``,
    ``ssh_password``, ``sql_username``, ``sql_password``,
    ``sql_main_database``) and on ``SSHTunnelForwarder``, ``mdb`` and ``pd``
    being imported elsewhere.
    """
    with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)
        try:
            cursor = connection.cursor()
            try:
                # Execute exactly once. Passing the query to pd.read_sql as
                # well ran it a second time (a second, never-committed write
                # for INSERT/UPDATE/DELETE statements).
                cursor.execute(_query)
                connection.commit()
                if cursor.description is None:
                    # No result set -- the case the old TypeError fallback
                    # was catching.
                    return cursor.fetchall()
                columns = [col[0] for col in cursor.description]
                # Build the frame from the rows already fetched
                # (coerce_float=True is pd.read_sql's default too).
                return pd.DataFrame.from_records(list(cursor.fetchall()),
                                                 columns=columns,
                                                 coerce_float=True)
            finally:
                cursor.close()
        finally:
            # Close the connection while the tunnel it runs through is open.
            connection.close()

我想把下面这一部分提取成一个单独的函数:

# Open an SSH tunnel to ssh_host and a MySQL connection through the tunnel's
# local port (server.local_bind_port). The tunnel is tied to this `with`
# block, so `connection` is only usable until the block is left.
with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)

然后在 query_result_connect() 函数中调用它。问题是我不知道如何让其余代码仍在该函数的 “with” 语句内部执行。代码大致应如下所示:

# Maybe introduce some arguments
def db_connection():
    """Question sketch: open the SSH tunnel and a MySQL connection through it.

    As written, the connection is only bound to a local name and the
    function returns None; the tunnel's ``with`` block ends before the
    function returns.
    """
    tunnel = SSHTunnelForwarder(
        (ssh_host, ssh_port),
        ssh_password=ssh_password,
        ssh_username=ssh_user,
        remote_bind_address=('127.0.0.1', 3306),
    )
    with tunnel as server:
        connection = mdb.connect(
            user=sql_username,
            passwd=sql_password,
            db=sql_main_database,
            host='127.0.0.1',
            port=server.local_bind_port,
        )
    # Maybe return something
    

def query_result_connect(_query):
    """Question sketch: run ``_query`` on an already-open ``connection``.

    ``connection`` is not created here yet -- obtaining it from
    ``db_connection()`` while that function's ``with`` block is still open is
    exactly what the question asks. Returns a pandas DataFrame for statements
    with a result set, otherwise ``cursor.fetchall()``.
    """
    # call the db_connection() function somehow.

    # Write the following code in a way that is within the 'with' statement of the db_connection() function.
    cursor = connection.cursor()
    try:
        # Execute once; also passing the query to pd.read_sql would run it a
        # second time.
        cursor.execute(_query)
        connection.commit()
        if cursor.description is None:
            # Statement without a result set (INSERT/UPDATE/...).
            return cursor.fetchall()
        columns = [col[0] for col in cursor.description]
        return pd.DataFrame.from_records(list(cursor.fetchall()),
                                         columns=columns,
                                         coerce_float=True)
    finally:
        cursor.close()

谢谢

4

2 个回答

0

何不把 do_connection(即问题中的 db_connection)本身写成一个上下文管理器?

from contextlib import contextmanager


@contextmanager
def do_connection():
    """Yield a MySQL connection opened through an SSH tunnel.

    Uses the same settings as the question's code -- assumes the question's
    imports (SSHTunnelForwarder, mdb) and ssh_*/sql_* variables are in scope.
    The connection is committed if the ``with`` body finishes without an
    exception, and it is always closed before the tunnel shuts down.
    """
    # Prepare the connection: SSH tunnel first, then MySQL through it.
    with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)
        try:
            # The caller's `with` body runs here, while the tunnel is open.
            yield connection
            # Only reached when the body did not raise.
            connection.commit()
        finally:
            # Close the connection (the __exit__ part).
            connection.close()

然后可以这样使用它:

with do_connection() as connection:
    # Everything that needs the database goes inside this block; leaving it
    # runs the cleanup part of do_connection() (closing the connection).
    cursor = connection.cursor()
    ...

这是使用上下文管理器创建数据库连接的常用方法。

于 2020-11-28T12:30:45.027 回答
0

你可以编写自己的 Connection 类,让它充当上下文管理器(context manager)。

__enter__ 负责建立 SSH 隧道和数据库连接。
__exit__ 负责尝试关闭游标、数据库连接和 SSH 隧道。

from sshtunnel import SSHTunnelForwarder
import psycopg2, traceback


class MyDatabaseConnection:
    """Context manager: an SSH tunnel plus a psycopg2 connection through it.

    ``__enter__`` starts the tunnel, connects to the database over the
    tunnel's local port and returns ``self`` with ``con`` (connection) and
    ``cur`` (cursor) set. ``__exit__`` closes the cursor, the connection and
    the tunnel, in that order. The connection settings below are
    placeholders and must be filled in.
    """

    def __init__(self):
        self.ssh_host = '...'
        self.ssh_port = 22
        self.ssh_user = '...'
        self.ssh_password = '...'
        self.local_db_port = 59059
        # Created in __enter__; initialised here so cleanup can always
        # inspect them, even if __enter__ failed part-way.
        self.tunnel = None
        self.con = None
        self.cur = None

    def _connect_db(self, dsn):
        """Connect to the database described by *dsn* and open a cursor.

        Errors propagate, so callers never get an object without
        ``con``/``cur``.
        """
        self.con = psycopg2.connect(dsn)
        self.cur = self.con.cursor()

    def _create_tunnel(self):
        """Start the SSH tunnel; return True if it is bound to local_db_port."""
        self.tunnel = SSHTunnelForwarder(
            (self.ssh_host, self.ssh_port),
            ssh_password=self.ssh_password,
            ssh_username=self.ssh_user,
            remote_bind_address=('localhost', 5959),
            local_bind_address=('localhost', self.local_db_port)
        )
        self.tunnel.start()
        return self.tunnel.local_bind_port == self.local_db_port

    def __enter__(self):
        try:
            if not self._create_tunnel():
                raise RuntimeError(
                    'SSH tunnel is not bound to local port %s'
                    % self.local_db_port
                )
            self._connect_db(
                "dbname=mf port=%s host='localhost' user=mf_usr" %
                self.local_db_port
            )
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release what
            # was already opened (e.g. the tunnel) before re-raising.
            self._close_all()
            raise
        return self

    def __exit__(self, *args):
        self._close_all()

    def _close_all(self):
        """Close cur, con and tunnel in that order, continuing past failures."""
        for name in ('cur', 'con', 'tunnel'):
            obj = getattr(self, name, None)
            if obj is None:
                continue
            try:
                obj.close()
            except Exception:
                # Best effort: report, but still close the remaining objects.
                traceback.print_exc()
            # Drop the reference so a repeated __exit__ is a no-op.
            setattr(self, name, None)


# Example: open the tunnel and connection, run one query, then let __exit__
# close the cursor, the connection and the tunnel.
with MyDatabaseConnection() as db:
    print(db)
    db.cur.execute('Select count(*) from platforms')
    print(db.cur.fetchone())

输出:

<__main__.MyDatabaseConnection object at 0x1017cb6d0>
(8,)

注意

我连接的是 Postgres,但同样的做法也适用于 MySQL,可能需要根据自己的情况进行调整(例如换用 MySQL 驱动并修改连接参数)。

于 2020-11-28T13:06:48.457 回答