Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Database Operation tool #5513

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: add Database Operation tool
  • Loading branch information
XiaoLey committed Jun 22, 2024
commit 1c7f56cf8bc28fff68456bd4d73f117083ff4695
27 changes: 27 additions & 0 deletions api/core/tools/provider/builtin/databaseoperation/_assets/icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.databaseoperation.tools.databasecontrol_sqlexec import DatabaseControlSqlExecTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class DatabaseOperationProvider(BuiltinToolProviderController):
    """Provider controller for the built-in Database Operation tool."""

    def _validate_credentials(self, credentials: dict) -> None:
        """
        Validate the provider by forking a runtime for the SQL exec tool

        The ``credentials`` argument is not read here; the check only
        confirms that the tool can be instantiated and forked with an
        empty runtime.

        :param credentials: Provider credentials (unused by this check)
        :raises ToolProviderCredentialValidationError: if forking the tool fails
        """
        try:
            DatabaseControlSqlExecTool().fork_tool_runtime(runtime={})
        except Exception as exc:
            # Chain explicitly so the original failure stays attached as __cause__.
            raise ToolProviderCredentialValidationError(str(exc)) from exc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Manifest for the "databaseoperation" tool provider: identity metadata
# (author, internal name, localized label/description, icon, tags).
identity:
author: Xiao Ley
# Internal provider name; it matches the provider package directory name.
name: databaseoperation
label:
en_US: Database Operation
zh_Hans: 数据库操作
description:
en_US: Provide database operation capability
zh_Hans: 提供操作数据库的能力
icon: icon.svg
tags:
- utilities
- productivity
# Left empty: no provider-level credentials are declared here.
credentials_for_provider:
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import json
from enum import Enum
from typing import Any

import sqlparse
from sqlparse.sql import Statement


class SqlTalkerBase:
    """
    Driver-independent skeleton for executing SQL and formatting the results

    The public methods perform connection-state checks, transaction handling
    and output formatting, then delegate to the underscore-prefixed hooks
    (``_connect``, ``_close``, ``_is_connected``, ``_begin_transaction``,
    ``_commit_transaction``, ``_rollback_transaction``, ``_exec_single``,
    ``_validate_sqls``) which subclasses must implement.

    Instances can be used as context managers: entering connects, leaving
    closes the connection.
    """

    # Supported output formats for exec().
    OutputType = Enum('OutputType', 'JSON Markdown')

    def __init__(self, host: str, port: str, username: str, password: str, dbname: str) -> None:
        """
        Store the connection settings; no connection is opened here

        :param host: Database host
        :param port: Database port
        :param username: User name
        :param password: Password
        :param dbname: Database name
        """
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.dbname = dbname

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def connect(self) -> None:
        """
        Open the connection

        :raises Exception: if a connection is already established
        """
        if self.is_connected():
            raise Exception("Connection is already established")

        self._connect()

    def close(self) -> None:
        """Close the connection if one is open; otherwise do nothing"""
        if self.is_connected():
            self._close()

    def is_connected(self) -> bool:
        """Return whether a connection is currently established"""
        return self._is_connected()

    def begin_transaction(self) -> None:
        """
        Begin a transaction

        :raises Exception: if no connection is established
        """
        if not self.is_connected():
            raise Exception("Connection is not established")

        self._begin_transaction()

    def commit_transaction(self) -> None:
        """
        Commit the current transaction

        :raises Exception: if no connection is established
        """
        if not self.is_connected():
            raise Exception("Connection is not established")

        self._commit_transaction()

    def rollback_transaction(self) -> None:
        """
        Roll back the current transaction

        :raises Exception: if no connection is established
        """
        if not self.is_connected():
            raise Exception("Connection is not established")

        self._rollback_transaction()

    def exec_single(self, sql: str) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Execute a single SQL statement

        :param sql: SQL statement
        :return: Result of the fields and data
        """
        return self._exec_single(sql)

    def validate_sqls(self, sqls: list[str]) -> bool:
        """
        Validate SQL statements

        :param sqls: List of SQL statements
        :return: True if SQL statements is supported, False otherwise
        """
        return self._validate_sqls(sqls)

    def exec(self, sql: str, output_format: OutputType = OutputType.JSON) -> str:
        """
        Execute one or more SQL statements

        All statements run inside one transaction: it is committed when every
        statement succeeds and rolled back when any of them raises.

        Statements that yield no field list contribute nothing to the output.
        In JSON mode the result is a JSON array holding one array of row
        objects per statement; values that are not JSON-native are rendered
        with ``str()``. In Markdown mode the result is one table per
        statement, separated by blank lines.

        :param sql: SQL statements
        :param output_format: Output format
        :return: Result of the SQL statements
        :raises Exception: if not connected or the statements are not supported
        """
        if not self.is_connected():
            raise Exception("Connection is not established")

        statements = self.split_sql_statements(sql)

        if not self.validate_sqls(statements):
            raise Exception("Unsupported SQL statements")

        as_markdown = output_format == self.OutputType.Markdown
        result_str_list = []

        self.begin_transaction()

        try:
            for stmt in statements:
                fields_list, data_list = self.exec_single(stmt)
                if not fields_list:
                    # The statement produced no result set; nothing to render.
                    continue

                rows = data_list or []
                if as_markdown:
                    result_str_list.append(self.convert_to_markdown_table(fields_list, rows))
                else:
                    # default=str keeps values such as dates, decimals or
                    # bytes from aborting the whole call with a TypeError.
                    result_str_list.append(json.dumps(rows, default=str))
        except Exception:
            self.rollback_transaction()
            raise
        else:
            self.commit_transaction()

        if as_markdown:
            return '\n\n'.join(result_str_list)
        else:
            return f'[{",".join(result_str_list)}]'

    @staticmethod
    def _markdown_cell(value: Any) -> str:
        """
        Render a value as the text of a single Markdown table cell

        Pipes are escaped and line breaks replaced with ``<br>`` so that the
        value cannot split the row or the table.
        """
        text = str(value).replace("|", "\\|")
        return text.replace("\r\n", "<br>").replace("\n", "<br>").replace("\r", "<br>")

    def convert_to_markdown_table(self, fields: list, data: list[dict[str, Any]]) -> str:
        """
        Convert a list of dictionaries to a Markdown table

        :param fields: List of field names
        :param data: List of dictionaries
        :return: Markdown table as a string
        """
        if not fields:
            return ""

        header_cells = [self._markdown_cell(field) for field in fields]
        markdown_table = [
            "| " + " | ".join(header_cells) + " |",
            "| " + " | ".join(["---"] * len(fields)) + " |",
        ]

        for item in data:
            # Missing keys render as an empty cell.
            row = [self._markdown_cell(item.get(header, "")) for header in fields]
            markdown_table.append("| " + " | ".join(row) + " |")

        return "\n".join(markdown_table)

    def get_sql_type(self, sql) -> str:
        """
        Get the SQL type of a SQL statement

        :param sql: SQL statement
        :return: SQL type (e.g. SELECT, INSERT, UPDATE, DELETE), or the text
                 "Empty or invalid SQL statement" when there is no leading token
        """
        parsed = sqlparse.parse(sql)
        if not parsed:
            return "Empty or invalid SQL statement"

        stmt: Statement = parsed[0]
        first_token = stmt.token_first(skip_cm=True)  # skip comments and whitespaces
        if first_token is None:
            # Statement contained only comments/whitespace.
            return "Empty or invalid SQL statement"

        return first_token.value.upper()

    def split_sql_statements(self, sql) -> list[str]:
        """
        Split a SQL string into individual statements

        :param sql: One or more SQL statements
        :return: Stripped, non-empty statements
        """
        statements = sqlparse.split(sql)
        return [stmt.strip() for stmt in statements if stmt.strip()]

    def _connect(self) -> None:
        raise NotImplementedError('SqlTalkerBase._connect() is not implemented')

    def _close(self) -> None:
        raise NotImplementedError('SqlTalkerBase._close() is not implemented')

    def _is_connected(self) -> bool:
        raise NotImplementedError('SqlTalkerBase._is_connected() is not implemented')

    def _begin_transaction(self) -> None:
        raise NotImplementedError('SqlTalkerBase._begin_transaction() is not implemented')

    def _commit_transaction(self) -> None:
        raise NotImplementedError('SqlTalkerBase._commit_transaction() is not implemented')

    def _rollback_transaction(self) -> None:
        raise NotImplementedError('SqlTalkerBase._rollback_transaction() is not implemented')

    def _exec_single(self, sql: str) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Execute a single SQL statement

        :param sql: SQL statement
        :return: Result of the fields and data
        """
        raise NotImplementedError('SqlTalkerBase._exec_single() is not implemented')

    def _validate_sqls(self, sqls: list[str]) -> bool:
        """
        Validate SQL statements

        :param sqls: List of SQL statements
        :return: True if SQL statements is supported, False otherwise
        """
        raise NotImplementedError('SqlTalkerBase._validate_sqls() is not implemented')
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import Any, Union

import psycopg2
import pymysql

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.databaseoperation.tools._sqltalker_base import SqlTalkerBase
from core.tools.tool.builtin_tool import BuiltinTool


class CommonSqlTalker(SqlTalkerBase):
    """
    SqlTalkerBase hooks shared by the driver-specific talkers

    Works on a connection object stored in ``self.conn`` that offers
    ``cursor()``, ``commit()``, ``rollback()`` and ``close()``. Subclasses
    provide ``_connect`` (which must assign ``self.conn``) and
    ``_begin_transaction``.
    """

    def __init__(self, host: str, port: str, username: str, password: str, dbname: str) -> None:
        super().__init__(host, port, username, password, dbname)
        # Driver connection; None means "not connected" (see _is_connected).
        self.conn = None

    def _close(self) -> None:
        """Close the connection and forget it, even if closing raises"""
        try:
            self.conn.close()
        finally:
            # Always drop the reference so is_connected() reports False
            # after a failed close instead of pointing at a dead connection.
            self.conn = None

    def _is_connected(self) -> bool:
        return self.conn is not None

    def _commit_transaction(self) -> None:
        self.conn.commit()

    def _rollback_transaction(self) -> None:
        self.conn.rollback()

    def _exec_single(self, sql: str) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Execute a single SQL statement

        :param sql: SQL statement
        :return: ``(field names, rows as dicts)``; ``(None, None)`` when the
                 statement produced no result set
        """
        cur = self.conn.cursor()
        try:
            cur.execute(sql)

            fields_tuple = cur.description
            if fields_tuple is None:
                # No result set (e.g. a statement that returns no rows).
                return None, None

            fields_list = [column[0] for column in fields_tuple]
            # Fetch whenever a result set exists; rowcount may be -1 when
            # the driver cannot determine it, so it is not used as a gate.
            data_list = [dict(zip(fields_list, row)) for row in cur.fetchall()]

            return fields_list, data_list
        finally:
            # Release the cursor on every path, including errors.
            cur.close()

    def _validate_sqls(self, sqls: list[str]) -> bool:
        """
        Validate SQL statements

        Only statements whose leading keyword is SELECT are accepted.

        NOTE(review): this checks the first keyword only; it is presumably not
        a full read-only guarantee for every dialect - confirm whether the
        connection should additionally be opened read-only.

        :param sqls: List of SQL statements
        :return: True when every statement is supported
        :raises Exception: listing the unsupported statements otherwise
        """
        # TODO: Add more SQL types in the future
        supported_types = ['SELECT']

        unsupported_sqls = [sql for sql in sqls if self.get_sql_type(sql) not in supported_types]

        if unsupported_sqls:
            numbered = [f"{i}. {sql}" for i, sql in enumerate(unsupported_sqls, start=1)]
            raise Exception("Unsupported SQL statements: \n" + "\n".join(numbered))

        return True


class PostgreSqlTalker(CommonSqlTalker):
    """CommonSqlTalker that connects through psycopg2."""

    def _connect(self) -> None:
        """Open a psycopg2 connection from the stored settings into ``self.conn``"""
        connection_settings = {
            'database': self.dbname,
            'user': self.username,
            'password': self.password,
            'host': self.host,
            'port': self.port,
        }
        self.conn = psycopg2.connect(**connection_settings)

    def _begin_transaction(self) -> None:
        """Switch autocommit off on the connection before statements run"""
        connection = self.conn
        connection.autocommit = False


class MySqlTalker(CommonSqlTalker):
    """CommonSqlTalker that connects through pymysql."""

    def _connect(self) -> None:
        """Open a pymysql connection from the stored settings into ``self.conn``"""
        connection_settings = {
            'database': self.dbname,
            'user': self.username,
            'password': self.password,
            'host': self.host,
            # The port is stored as a string; it is passed on as an int.
            'port': int(self.port),
        }
        self.conn = pymysql.connect(**connection_settings)

    def _begin_transaction(self) -> None:
        """Switch autocommit off, then explicitly begin a transaction"""
        connection = self.conn
        connection.autocommit(False)
        connection.begin()


class DatabaseControlSqlExecTool(BuiltinTool):
    """Built-in tool that runs SQL against a PostgreSQL or MySQL database."""

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any]
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Execute the SQL from ``tool_parameters`` and return it as a text message

        Keys read from ``tool_parameters``: ``sql``, ``output_format``
        (defaults to ``'json'``; only ``'markdown'`` selects Markdown),
        ``dbsystem`` (``'postgresql'`` or ``'mysql'``), ``host``, ``port``,
        ``username``, ``password`` and ``dbname``.

        :param user_id: Not used by this method
        :param tool_parameters: Tool parameters as described above
        :return: Text message holding the formatted query result
        :raises ValueError: if ``dbsystem`` is neither supported value
        """
        params = tool_parameters
        sql = params.get('sql')
        dbsystem = params.get('dbsystem')
        connection_args = (
            params.get('host'),
            params.get('port'),
            params.get('username'),
            params.get('password'),
            params.get('dbname'),
        )

        if params.get('output_format', 'json') == 'markdown':
            output_type = SqlTalkerBase.OutputType.Markdown
        else:
            output_type = SqlTalkerBase.OutputType.JSON

        # Choose the talker class first so both systems share one code path.
        if dbsystem == 'postgresql':
            talker_cls = PostgreSqlTalker
        elif dbsystem == 'mysql':
            talker_cls = MySqlTalker
        else:
            raise ValueError(f"Unsupported database system: {dbsystem}")

        with talker_cls(*connection_args) as sqltalker:
            result_str = sqltalker.exec(sql, output_format=output_type)

        return self.create_text_message(result_str)
Loading