之前有一个管理后台的需求:在前台提供一个支持SQL语句的查询工具。但很多时候,SQL语句写得不够严谨或出现慢SQL,会导致数据库宕机,造成不可预知的损失。因此在工作中,SQL上线前公司通常要求进行SQL审核;而人工审核效率很低,更好的做法是事先制定好规则,再对SQL进行自动审核。经过学习,我发现sqlparse模块能够很好地帮助我们实现这一点。
sqlparse介绍
项目主页(PyPI):https://pypi.org/project/sqlparse/ ,官方文档:https://sqlparse.readthedocs.io/
Python sqlparse是一个用于解析和格式化SQL语句的Python库。它可以将复杂的SQL语句解析成易于阅读和理解的结构化格式,并提供了一些有用的功能,如SQL语句的格式化、分析等。需要注意的是,sqlparse是一个非验证型(non-validating)解析器:它只做词法切分和结构分组,不会检查SQL语法是否正确,因此自动审核规则需要自己基于解析结果来编写。
安装
pip install sqlparse
使用
1. 解析sql语句
import sqlparse
from pprint import pprint

# split() breaks a string holding several statements into a list of statement strings.
query = 'Select a, col_2 as b from Table_A order by a desc;select * from base'
print(sqlparse.split(query))
# parse() returns a tuple of Statement objects, one per statement. Statements are
# separated by ";" -- the last statement does not need a trailing semicolon.
sql = 'select * from foo where id in (select id from bar);select * from base_order;'
parse_sql = sqlparse.parse(sql)
print(parse_sql)
print(type(parse_sql[0]), parse_sql[0])  # index to get one statement: <class 'sqlparse.sql.Statement'>
print(type(str(parse_sql[0])), str(parse_sql[0]))  # str() on a Statement gives back its SQL text
# Each Statement is a parse tree: .tokens holds its top-level tokens (each has a
# ttype and a value); grouped parts such as the WHERE clause are nested token lists.
for item in parse_sql:
    pprint(item.tokens)
"""
[<DML 'select' at 0x2013C8FB0A0>,
<Whitespace ' ' at 0x2013C8FB100>,
<Wildcard '*' at 0x2013C8FB160>,
<Whitespace ' ' at 0x2013C8FB1C0>,
<Keyword 'from' at 0x2013C8FB220>,
<Whitespace ' ' at 0x2013C8FB280>,
<Identifier 'foo' at 0x2013C8FF190>,
<Whitespace ' ' at 0x2013C8FB340>,
<Where 'where ...' at 0x2013C8FF120>]
"""
2. 格式化sql,增加可读性
sql = 'select * from foo where id in (select id from bar);select * from base_order;'
# reindent=True puts each clause on its own (indented) line -- this is what makes the
# output more readable; keyword_case='upper' upper-cases all keywords.
print(sqlparse.format(sql, reindent=True, keyword_case='upper'))
3. 分析sql语句
sql语句主要分为DDL(数据定义语言)和DML(数据操纵语言)
# sqlparse recognizes keywords (select, as, from, ...) as well as column and table
# names; column and table names are grouped into Identifier tokens.
sql = 'select order_id,id from foo where id in (select id from bar order by id) AND rownum <= 100 group by id ;'
parsed = sqlparse.parse(sql)
# 3.1 Statement type (SELECT, INSERT, UPDATE, DELETE, ALTER, ...)
print(parsed[0].get_type())  # SELECT
# 3.2 Collect the keywords that are direct children of the statement
keyword_list = []
for token in parsed[0].tokens:
    if token.ttype is sqlparse.tokens.Keyword:
        # Only top-level keyword tokens are collected. WHERE is not among them:
        # sqlparse groups the whole clause into a sqlparse.sql.Where node.
        keyword_list.append(token.value)
print(keyword_list)  # ['from', 'group by']
# 3.3 Collect the keywords that are direct children of the top-level WHERE clause
where_key_list = []
for token in parsed[0].tokens:
    if isinstance(token, sqlparse.sql.Where):
        # The clause text as a plain str
        print(token.value, type(token.value))
        print(token.ttype)  # None: grouped tokens have no ttype of their own
        print(token.tokens)  # child tokens; a subquery appears as a nested (Parenthesis) group
        for where_item in token.tokens:
            if where_item.ttype is sqlparse.tokens.Keyword:
                where_key_list.append(where_item.value)
        # The exact list depends on the sqlparse version (e.g. whether IN is a keyword token).
        print("where 字句中的关键字:", where_key_list)
        print(type(token.parent), token.parent)  # parent node: the whole statement
        print(token.flatten(), [i.value for i in token.flatten()])  # flatten() is a generator over the leaf tokens
        # is_group is True for every grouped token (a Where always is one); it does
        # NOT tell whether the clause contains a subquery.
        print(token.is_group)
        # Whether the clause contains a subquery: look for a SELECT among its leaf tokens.
        print(any(t.ttype is sqlparse.tokens.DML and t.normalized == 'SELECT' for t in token.flatten()))
# 3.4 提取表名
import sqlparse
from sqlparse.tokens import DML
from sqlparse.tokens import Keyword
from sqlparse.sql import Identifier
from sqlparse.sql import IdentifierList
# Join keywords (upper-case, single-spaced) after which a joined table follows.
# Plain JOIN and the CROSS / NATURAL / RIGHT OUTER variants are included so that
# tables joined with them are extracted as well.
ALL_JOIN_TYPE = (
    'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN', 'FULL JOIN', 'CROSS JOIN', 'NATURAL JOIN',
    'LEFT OUTER JOIN', 'RIGHT OUTER JOIN', 'FULL OUTER JOIN',
)
def is_subselect(parsed):
    """
    Tell whether a token is itself a (sub)query.

    A token counts as a subquery when it is grouped and one of its direct
    children is the SELECT DML keyword.

    :param parsed: sqlparse token to inspect
    :return: True for a grouped token containing SELECT, otherwise False
    """
    if parsed.is_group:
        return any(
            child.ttype is DML and child.value.upper() == 'SELECT'
            for child in parsed.tokens
        )
    return False
def extract_from_part(parsed):
    """
    Yield the tokens that follow a FROM keyword among the children of *parsed*.

    Collection stops at the next keyword token (e.g. a JOIN or GROUP BY) and
    resumes after a later FROM. Subqueries met in that range are searched
    recursively for their own FROM part instead of being yielded.
    """
    in_from_clause = False
    for token in parsed.tokens:
        if not in_from_clause:
            if token.ttype is Keyword and token.value.upper() == 'FROM':
                in_from_clause = True
        elif is_subselect(token):
            yield from extract_from_part(token)
        elif token.ttype is Keyword:
            in_from_clause = False
        else:
            yield token
def extract_join_part(parsed):
    """
    Yield the tokens that follow a join keyword among the children of *parsed*.

    A join keyword is any keyword token whose upper-cased text, with runs of
    whitespace collapsed to one space, is listed in ALL_JOIN_TYPE (so a join
    written with extra spaces or split across lines still matches). Tokens are
    yielded until the next keyword token (typically ON or USING).
    """
    in_join = False
    for item in parsed.tokens:
        is_keyword = item.ttype is Keyword
        if in_join:
            if is_keyword:
                # This keyword ends the joined-table part. Do not skip it: it may
                # itself be the next join, e.g. "a CROSS JOIN b CROSS JOIN c".
                in_join = False
            else:
                yield item
        if is_keyword and ' '.join(item.value.upper().split()) in ALL_JOIN_TYPE:
            in_join = True
def extract_table_identifiers(token_stream):
    """
    Turn a stream of tokens into table names.

    For Identifier tokens (and each identifier inside an IdentifierList) the
    real object name is yielded rather than the alias, so "orders o" yields
    "orders" and "db.orders" yields "orders"; get_name() is used as a fallback
    when no real name can be determined. Bare keyword tokens yield their raw
    text; every other token is ignored.
    """
    for item in token_stream:
        if isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                yield identifier.get_real_name() or identifier.get_name()
        elif isinstance(item, Identifier):
            yield item.get_real_name() or item.get_name()
        elif item.ttype is Keyword:
            yield item.value
def extract_tables(sql):
    """
    Extract the table names used by the first statement in *sql* (SELECT statements).

    Tables listed after FROM come first, followed by tables introduced by join
    keywords. Returns an empty list when *sql* contains no statement at all.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return []
    # Parse once and walk the same tree for both passes.
    statement = statements[0]
    from_tables = list(extract_table_identifiers(extract_from_part(statement)))
    join_tables = list(extract_table_identifiers(extract_join_part(statement)))
    return from_tables + join_tables
更多sqlparse token的用法(属性与方法),可以查看源码中的这两个类:
class sqlparse.sql.Token(ttype, value)
class sqlparse.sql.TokenList(tokens=None)
4. 实践场景简单应用
封装增删改查等各类数据库语句的执行,并将查询结果的返回条数限制为最多100条。
import sqlparse

# Statement types that modify data: executed as-is and then committed.
WRITE_STATEMENT_TYPES = ("INSERT", "DELETE", "UPDATE")
# Upper bound on the number of rows a query may return.
MAX_QUERY_ROWS = 100
# Top-level clauses that must come after WHERE; appending "where 1=1" behind any
# of them (e.g. "... order by a where 1=1") would produce invalid SQL.
CLAUSES_AFTER_WHERE = ("GROUP BY", "ORDER BY", "HAVING", "LIMIT", "UNION", "UNION ALL", "FOR")

sql_parse = sqlparse.parse(sql_string)
statement = sql_parse[0]
# Strip whitespace before the semicolon so input such as "select 1; " is handled too.
sql = str(statement).strip().rstrip(";")
# get_type() skips leading whitespace and comments. Looking at the first raw token
# instead would send e.g. "\ndelete from t" (first token: Whitespace) down the query path.
statement_type = statement.get_type()

if statement_type in WRITE_STATEMENT_TYPES:
    # Insert / delete / update: run and commit.
    with MyDb(db_pool) as db:
        db[0].execute(sql)
        db[1].commit()
    code, desc = 0, "success"
    data = dict(col=["result", ], value=[dict(result=desc)])
else:
    # Query: cap the number of returned rows at MAX_QUERY_ROWS.
    # Keywords that are direct children of the statement (upper-cased, whitespace
    # collapsed); WHERE is not among them because sqlparse groups it separately.
    f_keyword_list = []
    # Keywords found directly inside the top-level WHERE clause.
    f_w_keyword_list = []
    for f_token in statement.tokens:
        if f_token.ttype is sqlparse.tokens.Keyword:
            f_keyword_list.append(" ".join(f_token.value.upper().split()))
        if isinstance(f_token, sqlparse.sql.Where):
            for w_token in f_token:
                if w_token.ttype is sqlparse.tokens.Keyword:
                    f_w_keyword_list.append(w_token.value.upper())
                # Compare case-insensitively: "rownum" is usually written in lower case.
                if w_token.ttype is sqlparse.tokens.Name.Builtin and w_token.value.upper() == "ROWNUM":
                    f_w_keyword_list.append("ROWNUM")
    engine = db_info["ENGINE"]
    # Rule: a plain "select ... from table" gets a WHERE clause appended. Only do it
    # when the statement has a FROM and no clause that has to follow WHERE.
    if ("WHERE" not in f_w_keyword_list
            and "FROM" in f_keyword_list
            and not any(keyword in CLAUSES_AFTER_WHERE for keyword in f_keyword_list)):
        # Newline separator: a trailing "-- comment" would otherwise swallow the addition.
        sql = f"{sql}\nwhere 1=1"
    if engine == "mysql":
        # MySQL: append LIMIT unless the statement already has a top-level LIMIT.
        if "LIMIT" not in f_keyword_list:
            sql = f"{sql}\nlimit {MAX_QUERY_ROWS}"
    elif "ROWNUM" not in f_w_keyword_list:
        # Other engines (Oracle-style): wrap the query and filter on ROWNUM.
        sql = f"select * from ({sql}\n) where rownum <={MAX_QUERY_ROWS}"
    with MyDb(db_pool) as db:
        cursor = db[0]
        cursor.execute(sql)
        col_names = [col[0] for col in cursor.description]
        # Read CLOB values while the cursor is still open (instead of fetchall) so
        # the rows can be JSON-serialized later.
        rows = [
            tuple(value.read() if isinstance(value, cx_Oracle.LOB) else value for value in record)
            for record in cursor
        ]
    # Assign `data` as a whole, like the write branch does; `data["col"] = ...`
    # relied on `data` already existing, which nothing on this path guaranteed.
    data = dict(col=col_names, value=[dict(zip(col_names, row)) for row in rows])
    code, desc = 0, "success"
领取专属 10元无门槛券
私享最新 技术干货