Files
apiautotest/common_utils/mysql_handle.py
2023-05-27 14:45:48 +08:00

172 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
# @Time : 2023/5/12 11:04
# @Author : chenyinhua
# @File : mysql_handle.py
# @Software: PyCharm
# @Desc: 使用pymysql模块连接mysql数据库的公共方法
import pymysql
from typing import Union
import json
from datetime import datetime
from sshtunnel import SSHTunnelForwarder # pip install sshtunnel
from loguru import logger
class MysqlServer:
"""
初始化数据库连接(支持通过SSH隧道的方式连接),并指定查询的结果集以字典形式返回
"""
def __init__(self, db_host, db_port, db_user, db_pwd, db_database, ssh=False,
**kwargs):
"""
初始化方法中, 连接mysql数据库 根据ssh参数决定是否走SSH隧道方式连接mysql数据库
"""
logger.debug("\n======================================================\n" \
"-------------数据库配置信息--------------------\n"
f"db_host: {db_host}\n" \
f"db_port: {db_port}\n" \
f"db_user: {db_user}\n" \
f"db_pwd: {db_pwd}\n" \
f"db_database: {db_database}\n" \
f"ssh: {ssh}\n" \
f"kwargs: {kwargs}\n" \
"=====================================================")
self.server = None
try:
if ssh:
self.server = SSHTunnelForwarder(
ssh_address_or_host=(kwargs.get("ssh_host"), kwargs.get("ssh_port")), # ssh 目标服务器 ip 和 port
ssh_username=kwargs.get("ssh_user"), # ssh 目标服务器用户名
ssh_password=kwargs.get("ssh_pwd"), # ssh 目标服务器用户密码
remote_bind_address=(db_host, db_port), # mysql 服务ip 和 part
local_bind_address=('127.0.0.1', 5143), # ssh 目标服务器的用于连接 mysql 或 redis 的端口,该 ip 必须为 127.0.0.1
)
self.server.start()
db_host = self.server.local_bind_host # server.local_bind_host 是 参数 local_bind_address 的 ip
db_port = self.server.local_bind_port # server.local_bind_port 是 参数 local_bind_address 的 port
# 建立连接
self.conn = pymysql.connect(host=db_host,
port=db_port,
user=db_user,
password=db_pwd,
database=db_database,
charset="utf8",
cursorclass=pymysql.cursors.DictCursor # 加上pymysql.cursors.DictCursor这个返回的就是字典
)
# 创建一个游标对象
self.cursor = self.conn.cursor()
except Exception as e:
logger.error(f"数据库连接失败:{e}")
print(f"数据库连接失败:{e}")
def query_all(self, sql):
"""
查询所有符合sql条件的数据
:param sql: 执行的sql
:return: 查询结果
"""
try:
self.conn.commit()
self.cursor.execute(sql)
data = self.cursor.fetchall()
# 关闭数据库链接和隧道
self.close()
return self.verify(data)
except Exception as e:
logger.error(f"{sql} --> 报错: {e}")
print(f"{sql} --> 报错: {e}")
raise e
def query_one(self, sql):
"""
查询符合sql条件的数据的第一条数据
:param sql: 执行的sql
:return: 返回查询结果的第一条数据
"""
try:
self.conn.commit()
self.cursor.execute(sql)
data = self.cursor.fetchone()
# 关闭数据库链接和隧道
self.close()
return self.verify(data)
except Exception as e:
logger.error(f"{sql} --> 报错: {e}")
print(f"{sql} --> 报错: {e}")
raise e
def insert(self, sql):
"""
插入数据
:param sql: 执行的sql
"""
try:
self.cursor.execute(sql)
# 提交 只要数据库更新就要commit
self.conn.commit()
# 关闭数据库链接和隧道
self.close()
except Exception as e:
logger.error(f"{sql} --> 报错: {e}")
print(f"{sql} --> 报错: {e}")
raise e
def update(self, sql):
"""
更新数据
:param sql: 执行的sql
"""
try:
self.cursor.execute(sql)
# 提交 只要数据库更新就要commit
self.conn.commit()
# 关闭数据库链接和隧道
self.close()
except Exception as e:
logger.error(f"{sql} --> 报错: {e}")
print(f"{sql} --> 报错: {e}")
raise e
def query(self, sql, one=True):
"""
根据传值决定查询一条数据还是所有
:param sql: 查询的SQL语句
:param one: 默认True. True查一条数据否则查所有
:return:
"""
try:
if one:
return self.query_one(sql)
else:
return self.query_all(sql)
except Exception as e:
logger.error(f"{sql} --> 报错: {e}")
print(f"{sql} --> 报错: {e}")
raise e
def close(self):
"""
断开游标,关闭数据库
如果开启了SSH隧道也关闭
:return:
"""
# 关闭游标
self.cursor.close()
# 关闭数据库链接
self.conn.close()
# 如果开启了SSH隧道则关闭
if self.server:
self.server.close()
def verify(self, result: dict) -> Union[dict, None]:
"""验证结果能否被json.dumps序列化"""
# 尝试变成字符串解决datetime 无法被json 序列化问题
try:
json.dumps(result)
except TypeError: # TypeError: Object of type datetime is not JSON serializable
for k, v in result.items():
if isinstance(v, datetime):
result[k] = str(v)
return result