flake8 by skourta · Pull Request #18 · Tiramisu-Compiler/TiraLib · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

flake8 #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions athena/search_methods/sequential_parallelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@ def parallelize_first_legal_outermost(
tmp_schedule = schedule.copy()
for candidate in candidates_per_root[root]:
for node in candidate:
comps = tiramisu_program.tree.get_iterator_subtree_computations(node)
comps = (
tiramisu_program.tree.get_iterator_subtree_computations(
node
)
)
tmp_schedule.add_optimizations(
[
Parallelization(
[(comps[0], tiramisu_program.tree.iterators[node].level)]
[
(
comps[0],
tiramisu_program.tree.iterators[
node
].level,
)
]
)
]
)
Expand Down
119 changes: 69 additions & 50 deletions athena/tiramisu/compiling_service.py

Large diffs are not rendered by default.

34 changes: 21 additions & 13 deletions athena/tiramisu/function_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
schedule_str = argv[2];

std::string function_name = "{name}";

{body}

schedule_str_to_result_str(function_name, schedule_str, operation, {buffers});
return 0;
}}
"""
""" # noqa: E501


class ResultInterface:
Expand All @@ -65,20 +65,26 @@ def __init__(self, result_str: bytes) -> None:
self.success = result_dict["success"]

# convert exec_times to list of floats
self.exec_times = [float(x) for x in result_dict["exec_times"].split()] if result_dict["exec_times"] else []
self.exec_times = (
[float(x) for x in result_dict["exec_times"].split()]
if result_dict["exec_times"]
else []
)

self.additional_info = result_dict["additional_info"]

def __str__(self) -> str:
isl_ast = self.isl_ast.replace("\n", ",")
return f"ResultInterface(name={self.name},legality={self.legality},isl_ast={isl_ast},exec_times={self.exec_times},success={self.success})"
return f"ResultInterface(name={self.name},legality={self.legality},isl_ast={isl_ast},exec_times={self.exec_times},success={self.success})" # noqa: E501

def __repr__(self) -> str:
return self.__str__()


class FunctionServer:
def __init__(self, tiramisu_program: "TiramisuProgram", reuseServer: bool = False):
def __init__(
self, tiramisu_program: "TiramisuProgram", reuseServer: bool = False
):
if not BaseConfig.base_config:
raise ValueError("BaseConfig not initialized")

Expand All @@ -97,8 +103,10 @@ def __init__(self, tiramisu_program: "TiramisuProgram", reuseServer: bool = Fals
return

# Generate the server code
server_code = FunctionServer._generate_server_code_from_original_string(
tiramisu_program
server_code = (
FunctionServer._generate_server_code_from_original_string(
tiramisu_program
)
)

# Write the server code to a file
Expand Down Expand Up @@ -129,13 +137,13 @@ def _generate_server_code_from_original_string(
original_str = tiramisu_program.original_str
# Generate function
body = re.findall(
r"int main\([\w\s,*]+\)\s*\{([\W\w\s]*)tiramisu::codegen", original_str
r"int main\([\w\s,*]+\)\s*\{([\W\w\s]*)tiramisu::codegen",
original_str,
)[0]
name = re.findall(r"tiramisu::init\(\"(\w+)\"\);", original_str)[0]
# Remove the wrapper include from the original string
wrapper_str = f'#include "{name}_wrapper.h"'
original_str = original_str.replace(wrapper_str, f"// {wrapper_str}")
code_gen_line = re.findall(r"tiramisu::codegen\({.+;", original_str)[0]
buffers_vector = re.findall(
r"(?<=tiramisu::codegen\()\{[&\w,\s]+\}", original_str
)[0]
Expand All @@ -156,7 +164,7 @@ def _compile_server_code(self):
]
)

compileCommand = f"cd {BaseConfig.base_config.workspace} && {env_vars} && export FUNC_NAME={self.tiramisu_program.name} && $CXX -I$TIRAMISU_ROOT/3rdParty/Halide/install/include -I$TIRAMISU_ROOT/include -I$TIRAMISU_ROOT/3rdParty/isl/include -I$TIRAMISU_HERMESII_PATH/include -fvisibility-inlines-hidden -ftree-vectorize -fPIC -fstack-protector-strong -fno-plt -O3 -ffunction-sections -pipe -isystem $CONDA_ENV/include -ldl -g -fno-rtti -lpthread -std=c++17 -MD -MT ${{FUNC_NAME}}.cpp.o -MF ${{FUNC_NAME}}.cpp.o.d -o ${{FUNC_NAME}}.cpp.o -c ${{FUNC_NAME}}_server.cpp && $CXX -fvisibility-inlines-hidden -ftree-vectorize -fPIC -fstack-protector-strong -fno-plt -O3 -ffunction-sections -pipe -isystem $CONDA_ENV/include -ldl -g -fno-rtti -lpthread ${{FUNC_NAME}}.cpp.o -o ${{FUNC_NAME}}_server -L$TIRAMISU_ROOT/build -L$TIRAMISU_ROOT/3rdParty/Halide/install/lib64 -L$TIRAMISU_ROOT/3rdParty/isl/build/lib -Wl,-rpath,$TIRAMISU_ROOT/build:$TIRAMISU_ROOT/3rdParty/Halide/install/lib64:$TIRAMISU_ROOT/3rdParty/isl/build/lib:$TIRAMISU_HERMESII_PATH/lib $TIRAMISU_HERMESII_PATH/lib/libHermesII.so -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl -lsqlite3 $CONDA_ENV/lib/libz.so"
compileCommand = f"cd {BaseConfig.base_config.workspace} && {env_vars} && export FUNC_NAME={self.tiramisu_program.name} && $CXX -I$TIRAMISU_ROOT/3rdParty/Halide/install/include -I$TIRAMISU_ROOT/include -I$TIRAMISU_ROOT/3rdParty/isl/include -I$TIRAMISU_HERMESII_PATH/include -fvisibility-inlines-hidden -ftree-vectorize -fPIC -fstack-protector-strong -fno-plt -O3 -ffunction-sections -pipe -isystem $CONDA_ENV/include -ldl -g -fno-rtti -lpthread -std=c++17 -MD -MT ${{FUNC_NAME}}.cpp.o -MF ${{FUNC_NAME}}.cpp.o.d -o ${{FUNC_NAME}}.cpp.o -c ${{FUNC_NAME}}_server.cpp && $CXX -fvisibility-inlines-hidden -ftree-vectorize -fPIC -fstack-protector-strong -fno-plt -O3 -ffunction-sections -pipe -isystem $CONDA_ENV/include -ldl -g -fno-rtti -lpthread ${{FUNC_NAME}}.cpp.o -o ${{FUNC_NAME}}_server -L$TIRAMISU_ROOT/build -L$TIRAMISU_ROOT/3rdParty/Halide/install/lib64 -L$TIRAMISU_ROOT/3rdParty/isl/build/lib -Wl,-rpath,$TIRAMISU_ROOT/build:$TIRAMISU_ROOT/3rdParty/Halide/install/lib64:$TIRAMISU_ROOT/3rdParty/isl/build/lib:$TIRAMISU_HERMESII_PATH/lib $TIRAMISU_HERMESII_PATH/lib/libHermesII.so -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl -lsqlite3 $CONDA_ENV/lib/libz.so" # noqa: E501

# run the command and retrieve the execution status
try:
Expand All @@ -177,7 +185,7 @@ def run(
assert operation in [
"execution",
"legality",
], f"Invalid operation {operation}. Valid operations are: execution, legality, annotations"
], f"Invalid operation {operation}. Valid operations are: execution, legality, annotations" # noqa: E501

env_vars = " && ".join(
[
Expand All @@ -186,7 +194,7 @@ def run(
]
)

command = f'{env_vars} && cd {BaseConfig.base_config.workspace} && NB_EXEC={nbr_executions} ./{self.tiramisu_program.name}_server {operation} "{schedule or ""}"'
command = f'{env_vars} && cd {BaseConfig.base_config.workspace} && NB_EXEC={nbr_executions} ./{self.tiramisu_program.name}_server {operation} "{schedule or ""}"' # noqa: E501

# run the command and retrieve the execution status
try:
Expand All @@ -207,7 +215,7 @@ def get_annotations(self):
]
)

command = f"{env_vars} && cd {BaseConfig.base_config.workspace} && ./{self.tiramisu_program.name}_server annotations"
command = f"{env_vars} && cd {BaseConfig.base_config.workspace} && ./{self.tiramisu_program.name}_server annotations" # noqa: E501

# run the command and retrieve the execution status
try:
Expand Down
46 changes: 30 additions & 16 deletions athena/tiramisu/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class Schedule:
The list of optimizations to be applied to the Tiramisu program.
"""

def __init__(self, tiramisu_program: TiramisuProgram | None = None) -> None:
def __init__(
self, tiramisu_program: TiramisuProgram | None = None
) -> None:
self.tiramisu_program = tiramisu_program
self.optims_list: List[TiramisuAction] = []
if tiramisu_program:
Expand All @@ -43,7 +45,8 @@ def set_tiramisu_program(self, tiramisu_program: TiramisuProgram) -> None:

def add_optimizations(self, list_optim_cmds: List[TiramisuAction]) -> None:
"""
Adds a list of optimizations to the schedule while maintaining the schedule tree. The order of the optimizations in the list is important.
Adds a list of optimizations to the schedule while maintaining the
schedule tree. The order of the optimizations in the list is important.

Parameters
----------
Expand All @@ -61,7 +64,8 @@ def add_optimizations(self, list_optim_cmds: List[TiramisuAction]) -> None:

self.optims_list.append(optim_cmd)

# Fusion, distribution and tiling are special cases, we need to get the new tree with the new fusion levels
# Fusion, distribution and tiling are special cases,
# we need to get the new tree with the new fusion levels
if (
optim_cmd.is_fusion()
or optim_cmd.is_distribution()
Expand All @@ -87,7 +91,8 @@ def execute(
Parameters
----------
`nb_exec_times` : int
The number of times the Tiramisu program will be executed after applying the schedule.
The number of times the Tiramisu program will be executed after
applying the schedule.
Returns
-------
The execution time of the Tiramisu program after applying the schedule.
Expand All @@ -97,17 +102,19 @@ def execute(

if self.tiramisu_program.server:
result = self.tiramisu_program.server.run(
operation="execution", schedule=self, nbr_executions=nb_exec_times
operation="execution",
schedule=self,
nbr_executions=nb_exec_times,
)
if result.legality == False:
if result.legality is False:
raise Exception("Schedule is not legal")

return result.exec_times

if self.legality is None and self.optims_list:
self.is_legal()

if self.legality == False:
if self.legality is False:
raise Exception("Schedule is not legal")

return CompilingService.get_cpu_exec_times(
Expand Down Expand Up @@ -143,15 +150,19 @@ def is_legal(self, with_ast: bool = False) -> bool:
for action in self.optims_list:
if action.type == TiramisuActionType.SKEWING:
if action.params[2] == 0:
factors = result.additional_info.replace("skewing_factors:", "").split(",")
factors = result.additional_info.replace(
"skewing_factors:", ""
).split(",")
factors = [int(factor) for factor in factors]
action.params[2] = factors[0]
action.params[3] = factors[1]
action.factors = factors
action.set_string_representations(self.tree)
return result.legality

legality, new_tree = CompilingService.compile_legality(self, with_ast=with_ast)
legality, new_tree = CompilingService.compile_legality(
self, with_ast=with_ast
)

assert isinstance(legality, bool)
self.legality = legality
Expand All @@ -176,7 +187,9 @@ def update_tree_from_isl_ast(self):
isl_ast_str = CompilingService.compile_isl_ast_tree(
tiramisu_program=self.tiramisu_program, schedule=self
)
self.tree = TiramisuTree.from_isl_ast_string_list(isl_ast_str.split("\n"))
self.tree = TiramisuTree.from_isl_ast_string_list(
isl_ast_str.split("\n")
)

@classmethod
def from_sched_str(
Expand Down Expand Up @@ -204,7 +217,8 @@ def from_sched_str(
)

elif optimization_str[0] == "U":
# extract loop level, factor and comps using U\(L(\d),(\d+),comps=\[([\w',]*)\]\)
# extract loop level, factor and comps using
# U\(L(\d),(\d+),comps=\[([\w',]*)\]\)
regex = r"U\(L(\d),(\d+),comps=\[([\w', ]*)\]\)"
match = re.match(regex, optimization_str)
if match:
Expand Down Expand Up @@ -277,9 +291,7 @@ def from_sched_str(
]
)
elif optimization_str[:2] == "T3":
regex = (
r"T3\(L(\d),L(\d),L(\d),(\d+),(\d+),(\d+),comps=\[([\w', ]*)\]\)"
)
regex = r"T3\(L(\d),L(\d),L(\d),(\d+),(\d+),(\d+),comps=\[([\w', ]*)\]\)" # noqa: E501
match = re.match(regex, optimization_str)
if match:
outer_loop_level = int(match.group(1))
Expand All @@ -305,7 +317,9 @@ def from_sched_str(
]
)
elif optimization_str[0] == "S":
regex = r"S\(L(\d),L(\d),(-?\d+),(-?\d+),comps=\[([\w', ]*)\]\)"
regex = (
r"S\(L(\d),L(\d),(-?\d+),(-?\d+),comps=\[([\w', ]*)\]\)"
)
match = re.match(regex, optimization_str)
if match:
outer_loop_level = int(match.group(1))
Expand Down Expand Up @@ -344,7 +358,7 @@ def from_sched_str(
]
)
elif optimization_str[0] == "D":
regex = r"D\(L(\d),comps=\[([\w', ]*)\],distribution=([\[\]'\w, ]*)\)"
regex = r"D\(L(\d),comps=\[([\w', ]*)\],distribution=([\[\]'\w, ]*)\)" # noqa: E501
match = re.match(regex, optimization_str)
if match:
loop_level = int(match.group(1))
Expand Down
23 changes: 22 additions & 1 deletion athena/tiramisu/tiramisu_actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,26 @@
from .tiling_2d import Tiling2D
from .tiling_3d import Tiling3D
from .tiling_general import TilingGeneral
from .tiramisu_action import CannotApplyException, TiramisuAction, TiramisuActionType
from .tiramisu_action import (
CannotApplyException,
TiramisuAction,
TiramisuActionType,
)
from .unrolling import Unrolling

__all__ = [
"TiramisuAction",
"TiramisuActionType",
"CannotApplyException",
"Interchange",
"Tiling2D",
"Tiling3D",
"TilingGeneral",
"Parallelization",
"Skewing",
"Unrolling",
"Fusion",
"Reversal",
"Expansion",
"Distribution",
]
Loading
0