feat: SQL Job Runner + endpoint by tanmoysrt · Pull Request #163 · frappe/agent · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat: SQL Job Runner + endpoint #163

New issue
8000

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
66 changes: 65 additions & 1 deletion agent/database.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

import ast
import contextlib
import json
import time
from decimal import Decimal
from typing import Any
from typing import TYPE_CHECKING, Any

import peewee

if TYPE_CHECKING:
from agent.sql_runner import SQLQuery


class Database:
def __init__(self, host, port, user, password, database):
Expand Down Expand Up @@ -637,6 +642,65 @@ def _is_tcl_query(self, query: str) -> bool:
query = query.upper().replace(" ", "")
return query.startswith(("COMMIT", "ROLLBACK", "SAVEPOINT", "BEGINTRANSACTION"))

def _run_sql_v2(
    self, queries: list[SQLQuery], commit: bool = False, continue_on_error: bool = False
) -> list[SQLQuery]:
    """Execute ``queries`` in order inside one transaction and record each outcome.

    Every query object is updated in place (``success``, ``duration``,
    ``columns``, ``data``, ``row_count``, ``error_code``, ``error_message``)
    and the same list is returned.

    Args:
        queries: Query objects to execute; each must expose ``query`` (the SQL
            text) and accept the result attributes listed above.
        commit: Commit the transaction when every query succeeded. When False,
            the transaction is always rolled back.
        continue_on_error: When False, every query after the first failure is
            not executed and is marked ``PREVIOUS_QUERIES_FAILED``.

    Returns:
        The ``queries`` list that was passed in.
    """
    try:
        self.db.begin()
    except Exception:
        # Nothing can run if the initial begin() fails; report the same error
        # on every query so the caller gets one result per query.
        for q in queries:
            q.success = False
            q.error_code = "DatabaseError"
            q.error_message = "Failed to connect to database"
    else:
        with self.db.atomic() as transaction:
            is_any_query_failed = False

            for q in queries:
                # If previous queries failed and we don't want to continue on error
                # then we need to mark all the remaining queries as failed
                if is_any_query_failed and not continue_on_error:
                    q.success = False
                    q.error_code = "PREVIOUS_QUERIES_FAILED"
                    q.error_message = "Failed to execute previous queries"
                    continue

                if not self._execute_sql_v2_query(q):
                    is_any_query_failed = True

            # A single failure rolls back the whole batch, even when
            # continue_on_error let the later queries run.
            if not commit or is_any_query_failed:
                with contextlib.suppress(Exception):
                    transaction.rollback()
            else:
                transaction.commit()
    finally:
        # Best-effort close; a failure here must not mask the query results.
        with contextlib.suppress(Exception):
            self.db.close()

    return queries

def _execute_sql_v2_query(self, q: SQLQuery) -> bool:
    """Run one query via ``self.db.execute_sql`` and store the outcome on ``q``.

    Returns:
        True when the query executed and its result was read, False otherwise.
        ``q.duration`` is recorded in both cases.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()
    try:
        cursor = self.db.execute_sql(q.query)
        q.duration = time.perf_counter() - start_time
        q.success = True
        # Only statements that produce a result set have a description.
        if cursor.description:
            q.columns = [d[0] for d in cursor.description]
            q.data = cursor.fetchall()
        q.row_count = cursor.rowcount or 0
    except (peewee.ProgrammingError, peewee.InternalError, peewee.OperationalError) as e:
        q.duration = time.perf_counter() - start_time
        q.success = False
        q.error_code = "UNKNOWN"
        q.error_message = str(e)
        # The error text is presumably the driver's "(code, message)" tuple
        # rendered as a string; split it when it parses as a Python literal,
        # otherwise keep the raw text set above.
        with contextlib.suppress(Exception):
            mysql_error = ast.literal_eval(q.error_message)
            q.error_code = mysql_error[0]
            q.error_message = mysql_error[1]
    except Exception:
        q.duration = time.perf_counter() - start_time
        q.success = False
        q.error_code = "UNKNOWN"
        q.error_message = "Failed to execute query due to unknown error. Please check the query and try again later."  # noqa: E501

    return q.success


class JSONEncoderForSQLQueryResult(json.JSONEncoder):
def default(self, obj):
Expand Down
165 changes: 165 additions & 0 deletions agent/sql_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from __future__ import annotations

import datetime
import decimal
import re
from collections.abc import Iterable
from re import Match

from agent.base import Base
from agent.database import Database
from agent.job import Job, Step, job, step
from agent.server import Server


class SQLRunner(Base):
    """Run a batch of SQL queries against one database, inline or as an agent job.

    Credentials are either passed in directly or, when ``db_user`` or
    ``db_password`` is missing, taken from the site looked up through
    ``Server().get_bench(bench).get_site(site)``.
    """

    def __init__(
        self,
        database: str,
        db_host: str,
        db_port: int,
        queries: list[str],
        db_user: str | None = None,
        db_password: str | None = None,
        site: str | None = None,
        bench: str | None = None,
        read_only: bool = False,
        continue_on_error: bool = False,
    ):
        """Store connection settings, resolve credentials and parse ``queries``.

        Args:
            database: Name of the database to run the queries against.
            db_host: Database host.
            db_port: Database port.
            queries: Raw SQL strings; each is wrapped in a ``SQLQuery``.
            db_user: Database user; falls back to the site's user when either
                credential is missing.
            db_password: Database password; same fallback as ``db_user``.
            site: Site name, required when credentials are not both provided.
            bench: Bench name, required when credentials are not both provided.
            read_only: When True the transaction is never committed.
            continue_on_error: Keep executing queries after one has failed.

        Raises:
            Exception: If credentials are incomplete and ``site`` or ``bench``
                is missing, or if a query is rejected by ``SQLQuery``.
        """
        super().__init__()

        self.db_host = db_host
        self.db_port = db_port
        self.database = database
        self.db_user = db_user
        self.db_password = db_password
        self.site = site
        self.bench = bench
        self.read_only = read_only
        self.continue_on_error = continue_on_error

        # In case press hasn't provided the db_user and db_password, we will use the site and bench
        # So, ensure that site and bench are provided
        if (db_user is None or db_password is None) and (site is None or bench is None):
            raise Exception("site and bench are required if db_user or db_password is not provided")

        # Validate bench and site
        # NOTE(review): this also runs when only one of site/bench is given
        # together with explicit credentials; the lookup then receives None.
        if site or bench:
            bench_obj = Server().get_bench(bench)
            site_obj = bench_obj.get_site(site)
            # If db_user or db_password is not provided,
            # Assume we need to use the site's database credentials
            if db_user is None or db_password is None:
                self.db_user = site_obj.user
                self.db_password = site_obj.password

        self.job = None
        self.step = None

        self.queries: list[SQLQuery] = [SQLQuery(query) for query in queries]
        self.result = []

    @property
    def job_record(self):
        """Return the ``Job`` instance for this runner, creating it on first access."""
        if self.job is None:
            self.job = Job()
        return self.job

    @property
    def step_record(self):
        """Return the ``Step`` instance for this runner, creating it on first access."""
        if self.step is None:
            self.step = Step()
        return self.step

    @step_record.setter
    def step_record(self, value):
        self.step = value

    @job("Run SQL Queries")
    def run_sql_queries_job(self):
        """Run the queries through the ``job``/``step`` decorators and return the result."""
        self.run_sql_queries_step()
        return self.result

    @step("Run SQL Queries")
    def run_sql_queries_step(self):
        """Execute the queries and keep the outcome on ``self.result``."""
        self.result = self.run_sql_queries()

    def run_sql_queries(self) -> dict[str, list[dict]]:
        """Execute all queries and return ``{"data": [<query result dict>, ...]}``.

        The transaction is committed only when ``read_only`` is False.
        """
        database = Database(self.db_host, self.db_port, self.db_user, self.db_password, self.database)
        data = database._run_sql_v2(
            self.queries, commit=(not self.read_only), continue_on_error=self.continue_on_error
        )
        return {"data": [q.to_json() for q in data]}


class SQLQuery:
    """One SQL statement plus the metadata and result fields reported back for it."""

    # Matches the first /* ... */ comment; compiled once instead of per instance.
    _METADATA_PATTERN = re.compile(r"\/\*(.*?)\*\/")

    def __init__(self, raw_query: str):
        """
        Parse the SQL Queries to get the metadata

        Expected format from press:
            <query> /* <type>_<id> */

        Raises:
            Exception: "Invalid SQL Query Format" when the query has no
                ``/* ... */`` comment or the comment has no ``_`` separator.
        """
        try:
            comment = self._METADATA_PATTERN.search(raw_query).group(1)
            # Unpacking inside the try means a comment without "_" is reported
            # as an invalid format instead of leaking an IndexError later.
            query_type, query_id = comment.strip().split("_", 1)
        except Exception as e:
            raise Exception("Invalid SQL Query Format") from e

        self.id = query_id
        self.type = query_type
        self.query = raw_query
        self.columns = []
        self.data = []
        self.row_count = 0
        self.error_code = ""
        self.error_message = ""
        self.duration = 0.0
        self.success = False

    def to_json(self):
        """Return the query, its metadata and its result as a JSON-friendly dict."""
        # Normalise a missing/empty result (None, empty tuple) to an empty list.
        if not self.data:
            self.data = []

        return {
            "id": self.id,
            "type": self.type,
            "query": self.query,
            "columns": self.columns,
            "data": [[self._serialize(col) for col in row] for row in self.data],
            "row_count": self.row_count,
            "error_code": self.error_code,
            "error_message": self.error_message,
            "duration": self.duration,
            "success": self.success,
        }

    def _serialize(self, obj):
        """
        Serialize non-serializable data for json

        Source : https://github.com/frappe/frappe/blob/50a88149c15d419897d4d057bef1d63f79582c3a/frappe/__init__.py
        """

        try:
            if isinstance(obj, (int, float, str, bool)) or obj is None:
                return obj

            if isinstance(obj, decimal.Decimal):
                return float(obj)

            if isinstance(obj, (datetime.date, datetime.datetime, datetime.time)):
                return obj.isoformat()

            # str was already returned above, so this covers the remaining
            # iterables (e.g. bytes, tuples, sets).
            if isinstance(obj, Iterable):
                return list(obj)

            if isinstance(obj, Match):
                return obj.string

            if hasattr(obj, "__value__"):
                return obj.__value__()

            return str(obj)
        except Exception:
            # Last resort: never let one odd value break the whole response.
            return str(obj)
22 changes: 22 additions & 0 deletions agent/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from agent.proxysql import ProxySQL
from agent.security import Security
from agent.server import Server
from agent.sql_runner import SQLRunner
from agent.ssh import SSHProxy

if TYPE_CHECKING:
Expand Down Expand Up @@ -1065,6 +1066,27 @@ def physical_restore_database():
return {"job": job}


@application.route("/database/sql/execute", methods=["POST"])
def run_sql_queries():
    """Execute the SQL queries described by the JSON request body.

    When ``async_task`` is truthy the queries run through
    ``SQLRunner.run_sql_queries_job`` and the response is ``{"job": ...}``;
    otherwise they run inline and the runner's result is returned directly.
    """
    payload = request.json
    run_async = payload["async_task"]
    database = payload["database"]
    credential = payload["database_credential"]

    runner = SQLRunner(
        database=database,
        db_host=credential["host"],
        db_port=credential["port"],
        queries=payload["queries"],
        db_user=credential.get("user"),
        db_password=credential.get("password"),
        site=payload.get("site"),
        bench=payload.get("bench"),
        # Requests that do not say otherwise are treated as read-only.
        read_only=payload.get("read_only", True),
        continue_on_error=payload.get("continue_on_error", False),
    )

    if not run_async:
        return runner.run_sql_queries()
    return {"job": runner.run_sql_queries_job()}


@application.route("/database/binary/logs")
def get_binary_logs():
    """Return ``DatabaseServer().binary_logs`` as a JSON response."""
    binary_logs = DatabaseServer().binary_logs
    return jsonify(binary_logs)
Expand Down
0