

公司的数据分析师提交的 SQL 一般都有三四百行。出于数据安全的需要,不能把所有的数据库和数据表都开放给数据分析师查询,因此需要解析出 SQL 中用到的数据库和表,再与权限管理系统中记录的库表权限信息进行比对,从而拦截非法查询。


在解决这个问题之前,先在 GitHub 上找了一下轮子,发现 Python 下除了 sqlparse 之外,没有什么好用的解析数据库和表名的轮子;倒是在 Java 里找到了 presto-parser,解析得比较准。于是自己结合 sqlparse 的源码写了一个类,供大家参考。测试了一下,检测结果还是准确的。


# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.tokens import Keyword, Name


# Set operators that start a new SELECT, after which table names may follow.
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'}
# Keyword that ends the table list of a JOIN and starts its condition.
ON_KEYWORD = 'ON'
# Keywords after which the following identifier(s) are treated as tables.
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}


class BaseExtractor(object):
  """Collect the table names and the LIMIT value used by a SQL string.

  The SQL is normalised with ``sqlparse.format`` (re-indented, upper-case
  keywords), stripped of surrounding whitespace/semicolons and parsed with
  ``sqlparse.parse``. Every parsed statement is walked once to gather table
  names; names that turn out to be aliases of sub-selects are removed.
  """

  def __init__(self, sql_statement):
    """Parse ``sql_statement`` and pre-compute tables and limit."""
    self.sql = sqlparse.format(sql_statement, reindent=True, keyword_case='upper')
    self._table_names = set()
    self._alias_names = set()
    self._limit = None
    self._parsed = sqlparse.parse(self.stripped())
    for statement in self._parsed:
      # Without this walk ``_table_names`` would always stay empty.
      self.__extract_from_token(statement)
      self._limit = self._extract_limit_from_query(statement)
    self._table_names = self._table_names - self._alias_names

  def tables(self):
    """Return the set of table names found (aliases excluded)."""
    return self._table_names

  def limit(self):
    """Return the LIMIT value of the last parsed statement, or None."""
    return self._limit

  def is_select(self):
    """Return True when the first statement is a SELECT."""
    # An empty SQL string parses to no statements at all.
    return bool(self._parsed) and self._parsed[0].get_type() == 'SELECT'

  def is_explain(self):
    """Return True when the stripped SQL starts with EXPLAIN."""
    return self.stripped().upper().startswith('EXPLAIN')

  def is_readonly(self):
    """Return True for SELECT or EXPLAIN statements."""
    return self.is_select() or self.is_explain()

  def stripped(self):
    """Return the SQL without surrounding whitespace and semicolons."""
    return self.sql.strip(' \t\n;')

  def get_statements(self):
    """Return every non-empty parsed statement as a stripped string."""
    statements = []
    for statement in self._parsed:
      if statement:
        sql = str(statement).strip(' \n;\t')
        if sql:
          statements.append(sql)
    return statements

  @staticmethod
  def __precedes_table_name(token_value):
    """Return True if ``token_value`` contains a table-introducing keyword."""
    # Substring match so compound keywords such as 'LEFT JOIN' also count.
    return any(keyword in token_value for keyword in PRECEDES_TABLE_NAME)

  @staticmethod
  def get_full_name(identifier):
    """Return ``'<first>.<second>'`` for a dotted identifier.

    Identifiers without a dot fall back to ``identifier.get_real_name()``.
    Subclasses may override this to change how names are rendered.
    """
    tokens = identifier.tokens
    if len(tokens) > 2 and tokens[1].value == '.':
      return '{}.{}'.format(tokens[0].value, tokens[2].value)
    return identifier.get_real_name()

  @staticmethod
  def __is_result_operation(keyword):
    """Return True if ``keyword`` contains UNION, INTERSECT or EXCEPT."""
    upper_keyword = keyword.upper()
    return any(operation in upper_keyword for operation in RESULT_OPERATIONS)

  @staticmethod
  def __is_identifier(token):
    """Return True for sqlparse Identifier / IdentifierList tokens."""
    return isinstance(token, (IdentifierList, Identifier))

  def __process_identifier(self, identifier):
    """Record ``identifier`` as a table, or descend into a sub-select."""
    # No parenthesis means a plain table reference, not a sub-select.
    if '(' not in '{}'.format(identifier):
      self._table_names.add(self.get_full_name(identifier))
      return

    # store aliases
    if hasattr(identifier, 'get_alias'):
      self._alias_names.add(identifier.get_alias())
    if hasattr(identifier, 'tokens'):
      # some aliases are not parsed properly
      if identifier.tokens[0].ttype == Name:
        self._alias_names.add(identifier.tokens[0].value)
    # The sub-select itself may reference further tables.
    self.__extract_from_token(identifier)

  def as_create_table(self, table_name, overwrite=False):
    """Wrap the stripped SQL in ``CREATE TABLE <table_name> AS``.

    When ``overwrite`` is true a ``DROP TABLE IF EXISTS`` statement is
    emitted first.
    """
    exec_sql = ''
    sql = self.stripped()
    if overwrite:
      exec_sql = 'DROP TABLE IF EXISTS {};\n'.format(table_name)
    exec_sql += 'CREATE TABLE {} AS \n{}'.format(table_name, sql)
    return exec_sql

  def __extract_from_token(self, token):
    """Walk ``token`` recursively and collect table names into the instance."""
    if not hasattr(token, 'tokens'):
      return

    table_name_preceding_token = False

    for item in token.tokens:
      # Groups that are not identifiers (parentheses, WHERE, ...) may hide
      # nested queries, so they are always searched.
      if item.is_group and not self.__is_identifier(item):
        self.__extract_from_token(item)

      if item.ttype in Keyword:
        if self.__precedes_table_name(item.value.upper()):
          table_name_preceding_token = True
          continue

      if not table_name_preceding_token:
        continue

      if item.ttype in Keyword or item.value == ',':
        if (self.__is_result_operation(item.value) or
            item.value.upper() == ON_KEYWORD):
          table_name_preceding_token = False
          continue
        # FROM clause is over
        break

      if isinstance(item, Identifier):
        self.__process_identifier(item)

      if isinstance(item, IdentifierList):
        for sub_token in item.tokens:
          if self.__is_identifier(sub_token):
            self.__process_identifier(sub_token)

  def _get_limit_from_token(self, token):
    """Return the integer limit carried by ``token``, or None."""
    if token.ttype == sqlparse.tokens.Literal.Number.Integer:
      return int(token.value)
    elif token.is_group:
      return int(token.get_token_at_offset(1).value)
    return None

  def _extract_limit_from_query(self, statement):
    """Return the value after the first top-level LIMIT keyword, or None."""
    tokens = statement.tokens
    for pos, item in enumerate(tokens):
      if item.ttype in Keyword and item.value.lower() == 'limit':
        # The value sits two tokens later (keyword, whitespace, value);
        # a dangling LIMIT at the very end has no value to read.
        if pos + 2 >= len(tokens):
          return None
        return self._get_limit_from_token(tokens[pos + 2])
    return None

  def get_query_with_new_limit(self, new_limit):
    """Return the first statement's SQL with its LIMIT set to ``new_limit``.

    When no limit was detected, `` LIMIT <new_limit>`` is appended to the
    stripped SQL. The parsed tokens are left untouched.
    """
    # ``is None`` rather than falsiness, so an existing ``LIMIT 0`` is
    # replaced instead of getting a second LIMIT clause appended.
    if self._limit is None:
      return self.stripped() + ' LIMIT ' + str(new_limit)

    tokens = self._parsed[0].tokens
    limit_pos = None
    for pos, item in enumerate(tokens):
      if item.ttype in Keyword and item.value.lower() == 'limit':
        limit_pos = pos
        break
    # ``_limit`` comes from the last statement; the first one may have none.
    if limit_pos is None or limit_pos + 2 >= len(tokens):
      return self.stripped() + ' LIMIT ' + str(new_limit)

    value_pos = limit_pos + 2
    limit = tokens[value_pos]
    replacement = None
    if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
      replacement = str(new_limit)
    elif limit.is_group:
      # For a grouped limit value, keep its first identifier and replace
      # what follows it with the new limit.
      replacement = '{}, {}'.format(next(limit.get_identifiers()), new_limit)

    pieces = []
    for pos, item in enumerate(tokens):
      if pos == value_pos and replacement is not None:
        pieces.append(replacement)
      else:
        pieces.append(str(item.value))
    return ''.join(pieces)

class SqlExtractor(BaseExtractor):
  """Extractor that renders table names with up to three dotted parts."""

  @staticmethod
  def get_full_name(identifier, including_dbs=False):
    """Return the dotted name carried by ``identifier``.

    Args:
      identifier: a token group whose ``tokens`` each expose ``value``.
      including_dbs: accepted for interface compatibility; not used.

    Returns:
      ``'a.b'`` for a two-part name, ``'a.b.c'`` when a second dot follows
      the second part, and the tuple ``(None, None)`` when the identifier
      is not dotted at all.
    """
    tokens = identifier.tokens
    if len(tokens) > 2 and tokens[1].value == '.':
      first = tokens[0].value
      second = tokens[2].value
      two_part = '{}.{}'.format(first, second)
      # A third part exists only when another dot follows the second part;
      # anything else there (space, newline, alias) ends the name.
      if len(tokens) < 5 or tokens[3].value != '.':
        return two_part
      return '{}.{}.{}'.format(first, second, tokens[4].value)
    # NOTE(review): the fallback is a tuple while the other paths return a
    # string; kept as-is so existing callers see the same value.
    return None, None

if __name__ == '__main__':
  # Small self-contained demonstration: print the normalised SQL and the
  # fully qualified tables it references.
  demo_sql = (
    'select o.id, f.task_id '
    'from hive.bdc_dwd.dw_mk_order o '
    'join mysql4.dataview_fenxiao.fx_order_task f on o.id = f.order_id '
    'limit 10'
  )
  sql_extractor = SqlExtractor(demo_sql)
  print(sql_extractor.sql)
  print(sql_extractor.tables())

# Sample output recorded for the author's original (much longer) query:
# {'mysql8.dataview_tprc.tprc_task', 'hive.bdc_dwd.dw_mk_order', 'mysql4.dataview_fenxiao.fx_order_task', 'mysql4.dataview_fenxiao.fx_order', 'hive.bdc_dwd.dw_mk_order_business', 'mysql7.dataview_trade.mk_order_merchant', 'mysql4.dataview_scrm.sc_tprc_product_info', 'hive.bdc_dwd.dw_fn_deal_asyn_order', 'hive.bdc_dwd.dw_fact_task_ss_daily', 'mysql7.dataview_trade.ddc_product_info', 'hive.bdc_dwd.dw_mk_order_status'}


