add SSTORE and SLOAD to traces by 0xkarmacoma · Pull Request #456 · a16z/halmos · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

add SSTORE and SLOAD to traces #456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
unsat,
)

import halmos.traces

from .build import (
build_output_iterator,
import_libs,
Expand Down Expand Up @@ -401,6 +403,9 @@ def run_test(ctx: FunctionContext) -> TestResult:
if args.verbose >= 1:
print(f"Executing {funname}")

# set the config for every trace rendered in this test
halmos.traces.config_context.set(args)

#
# prepare calldata
#
Expand Down Expand Up @@ -742,6 +747,7 @@ def run_contract(ctx: ContractContext) -> list[TestResult]:
contract_ctx=ctx,
)

halmos.traces.config_context.set(setup_config)
setup_ex = setup(setup_ctx)
except Exception as err:
error(f"{setup_info.sig} failed: {type(err).__name__}: {err}")
Expand Down
41 changes: 37 additions & 4 deletions src/halmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Callable, Generator
from dataclasses import MISSING, dataclass, fields
from dataclasses import field as dataclass_field
from enum import Enum
from pathlib import Path
from typing import Any

Expand All @@ -26,6 +27,12 @@
)


class TraceEvent(Enum):
    """Kinds of trace events that can be selected for inclusion in rendered traces.

    Each member's value is the token accepted in the comma-separated
    ``trace_events`` config option (parsed via ``TraceEvent(token)`` in
    ``ParseCSVTraceEvent``). ``render_trace`` only prints an event of a given
    kind when its member is present in ``config.trace_events``.
    """

    # Declaration order matters: the ``trace_events`` default joins the values
    # of every member in iteration order.
    LOG = "LOG"
    SSTORE = "SSTORE"
    SLOAD = "SLOAD"


def find_venv_root() -> Path | None:
# If the environment variable is set, use that
if "VIRTUAL_ENV" in os.environ:
Expand Down Expand Up @@ -88,9 +95,28 @@ def parse_csv(values: str, sep: str = ",") -> Generator[Any, None, None]:
return (x for _x in values.split(sep) if (x := _x.strip()))


class ParseCSV(argparse.Action):
class ParseCSVTraceEvent(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
values = ParseCSV.parse(values)
values = ParseCSVTraceEvent.parse(values)
setattr(namespace, self.dest, values)

@staticmethod
def parse(values: str) -> list[TraceEvent]:
# empty list is ok
try:
return [TraceEvent(x) for x in parse_csv(values)]
except ValueError as e:
valid = ", ".join([e.value for e in TraceEvent])
raise ValueError(f"the list of valid trace events is: {valid}") from e

@staticmethod
def unparse(values: list[TraceEvent]) -> str:
return ",".join([x.value for x in values])


class ParseCSVInt(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
values = ParseCSVInt.parse(values)
setattr(namespace, self.dest, values)

@staticmethod
Expand Down Expand Up @@ -277,14 +303,14 @@ class Config:
help="set default lengths for dynamic-sized arrays (excluding bytes and string) not specified in --array-lengths",
global_default="0,1,2",
metavar="LENGTH1,LENGTH2,...",
action=ParseCSV,
action=ParseCSVInt,
)

default_bytes_lengths: str = arg(
help="set default lengths for bytes and string types not specified in --array-lengths",
global_default="0,65,1024", # 65 is ECDSA signature size
metavar="LENGTH1,LENGTH2,...",
action=ParseCSV,
action=ParseCSVInt,
)

storage_layout: str = arg(
Expand Down Expand Up @@ -424,6 +450,13 @@ class Config:
group=debugging,
)

trace_events: str = arg(
help="include specific events in traces",
global_default=",".join([e.value for e in TraceEvent]),
metavar="EVENT1,EVENT2,...",
action=ParseCSVTraceEvent,
)

### Build options

forge_build_out: str = arg(
Expand Down
23 changes: 21 additions & 2 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ForwardRef,
Optional,
TypeVar,
Union,
)

import rich
Expand Down Expand Up @@ -344,6 +345,20 @@ class EventLog:
data: Bytes | None


@dataclass(frozen=True)
class StorageWrite:
    """Trace element recording a storage write (SSTORE).

    ``sstore`` appends one of these to ``ex.context.trace`` *before* it checks
    for a static call context, so a recorded write may be one that then raised
    ``WriteInStaticContext``. This is intentional, to aid debugging.
    """

    # account whose storage is written
    address: Address
    # storage key being written
    slot: Word
    # value passed to sstore for that key
    value: Word


@dataclass(frozen=True)
class StorageRead:
    """Trace element recording a storage read (SLOAD).

    ``sload`` appends one of these to ``ex.context.trace`` after obtaining the
    value from the storage model.
    """

    # account whose storage is read
    address: Address
    # storage key being read
    slot: Word
    # value returned by the storage model's load for (address, slot)
    value: Word


@dataclass(frozen=True)
class Message:
target: Address
Expand Down Expand Up @@ -387,7 +402,7 @@ class CallOutput:
# - gas_left


TraceElement = ForwardRef("CallContext") | EventLog
TraceElement = Union["CallContext", EventLog, StorageRead, StorageWrite]


@dataclass
Expand Down Expand Up @@ -2037,9 +2052,13 @@ def mk_storagedata(self) -> StorageData:
return self.storage_model.mk_storagedata()

def sload(self, ex: Exec, addr: Any, loc: Word) -> Word:
return self.storage_model.load(ex, addr, loc)
val = self.storage_model.load(ex, addr, loc)
ex.context.trace.append(StorageRead(addr, loc, val))
return val

def sstore(self, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
ex.context.trace.append(StorageWrite(addr, loc, val))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to add this event even in the case of raising WriteInStaticContext below?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I think it's going to be useful for debugging - so you know what it tried to write that caused the revert

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

example
image

if ex.message().is_static:
raise WriteInStaticContext(ex.context_str())

Expand Down
69 changes: 58 additions & 11 deletions src/halmos/traces.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
import io
import sys
from contextvars import ContextVar

from z3 import Z3_OP_CONCAT, BitVecNumRef, BitVecRef, is_app

from halmos.bytevec import ByteVec
from halmos.config import Config, TraceEvent
from halmos.exceptions import HalmosException
from halmos.mapper import DeployAddressMapper, Mapper
from halmos.sevm import CallContext, EventLog, mnemonic
from halmos.sevm import CallContext, EventLog, StorageRead, StorageWrite, mnemonic
from halmos.utils import (
Address,
byte_length,
cyan,
green,
hexify,
is_bv,
magenta,
red,
unbox_int,
yellow,
)

config_context: ContextVar[Config | None] = ContextVar("config", default=None)


def rendered_initcode(context: CallContext) -> str:
message = context.message
Expand Down Expand Up @@ -73,6 +79,16 @@ def render_output(context: CallContext, file=sys.stdout) -> None:
)


def rendered_address(addr: Address) -> str:
    """Return the display string for an address in trace output.

    The address is first passed through ``unbox_int``. A value that is still a
    bit-vector afterwards is rendered with ``str``; otherwise it is rendered
    with ``hex``. The resulting string is then looked up in the deployment
    mapper, which (per the original code's comment) substitutes a contract name
    when one is known for that address.
    """
    unboxed = unbox_int(addr)
    if is_bv(unboxed):
        candidate = str(unboxed)
    else:
        candidate = hex(unboxed)

    # prefer a contract name from the deployment mapper when available
    return DeployAddressMapper().get_deployed_contract(candidate)


def rendered_log(log: EventLog) -> str:
opcode_str = f"LOG{len(log.topics)}"
topics = [
Expand All @@ -84,6 +100,28 @@ def rendered_log(log: EventLog) -> str:
return f"{opcode_str}({args_str})"


def rendered_slot(slot: BitVecRef | int) -> str:
    """Return the magenta-colored display string for a storage slot.

    Used by the SSTORE/SLOAD trace renderers. The slot is a storage key (the
    ``slot: Word`` field of ``StorageRead``/``StorageWrite``), not an address,
    so it is annotated as a z3 bit-vector or int.

    Rendering rules, applied after ``unbox_int``:
      - still a bit-vector -> ``hexify(slot)``
      - int below 2**16    -> decimal
      - any other int      -> hex
    """
    slot = unbox_int(slot)

    if is_bv(slot):
        text = hexify(slot)
    elif slot < 2**16:
        # small slot numbers are easier to read in decimal
        text = str(slot)
    else:
        # large slot numbers (presumably hashed keys) are easier to read in hex
        text = hex(slot)

    return magenta(text)


def rendered_sstore(update: StorageWrite) -> str:
    """Format a StorageWrite trace element as ``SSTORE @<slot> ← <value>``."""
    location = rendered_slot(update.slot)
    written = hexify(update.value)
    return cyan("SSTORE") + f" @{location} ← {written}"


def rendered_sload(read: StorageRead) -> str:
    """Format a StorageRead trace element as ``SLOAD @<slot> → <value>``."""
    location = rendered_slot(read.slot)
    loaded = hexify(read.value)
    return cyan("SLOAD") + f" @{location} → {loaded}"


def rendered_trace(context: CallContext) -> str:
with io.StringIO() as output:
render_trace(context, file=output)
Expand All @@ -106,11 +144,12 @@ def rendered_calldata(calldata: ByteVec, contract_name: str | None = None) -> st


def render_trace(context: CallContext, file=sys.stdout) -> None:
config: Config = config_context.get()
if config is None:
raise HalmosException("config not set")

message = context.message
addr = unbox_int(message.target)
addr_str = str(addr) if is_bv(addr) else hex(addr)
# check if we have a contract name for this address in our deployment mapper
addr_str = DeployAddressMapper().get_deployed_contract(addr_str)
addr_str = rendered_address(message.target)

value = unbox_int(message.value)
value_str = f" (value: {value})" if is_bv(value) or value > 0 else ""
Expand Down Expand Up @@ -147,12 +186,20 @@ def render_trace(context: CallContext, file=sys.stdout) -> None:

log_indent = (context.depth + 1) * " "
for trace_element in context.trace:
if isinstance(trace_element, CallContext):
render_trace(trace_element, file=file)
elif isinstance(trace_element, EventLog):
print(f"{log_indent}{rendered_log(trace_element)}", file=file)
else:
raise HalmosException(f"unexpected trace element: {trace_element}")
match trace_element:
case CallContext():
render_trace(trace_element, file=file)
case EventLog():
if TraceEvent.LOG in config.trace_events:
print(f"{log_indent}{rendered_log(trace_element)}", file=file)
case StorageRead():
if TraceEvent.SLOAD in config.trace_events:
print(f"{log_indent}{rendered_sload(trace_element)}", file=file)
case StorageWrite():
if TraceEvent.SSTORE in config.trace_events:
print(f"{log_indent}{rendered_sstore(trace_element)}", file=file)
case _:
raise HalmosException(f"unexpected trace element: {trace_element}")

render_output(context, file=file)

Expand Down
2 changes: 1 addition & 1 deletion src/halmos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def cyan(text: str) -> str:


def magenta(text: str) -> str:
return f"\033[35m{text}\033[0m"
return f"\033[95m{text}\033[0m"


color_good = green
Expand Down
30 changes: 15 additions & 15 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from halmos.config import (
Config,
ParseArrayLengths,
ParseCSV,
ParseCSVInt,
ParseErrorCodes,
arg_parser,
default_config,
Expand Down Expand Up @@ -192,21 +192,21 @@ def test_config_pickle(config, parser):

def test_parse_csv():
with pytest.raises(ValueError):
ParseCSV.parse("")
ParseCSV.parse(" ")
ParseCSV.parse(",")
assert ParseCSV.parse("0") == [0]
assert ParseCSV.parse("0,") == [0]
assert ParseCSV.parse("1,2,3") == [1, 2, 3]
assert ParseCSV.parse("1,2,3,") == [1, 2, 3]
assert ParseCSV.parse(" 1 , 2 , 3 ") == [1, 2, 3]
assert ParseCSV.parse(" , 1 , 2 , 3 , ") == [1, 2, 3]
ParseCSVInt.parse("")
ParseCSVInt.parse(" ")
ParseCSVInt.parse(",")
assert ParseCSVInt.parse("0") == [0]
assert ParseCSVInt.parse("0,") == [0]
assert ParseCSVInt.parse("1,2,3") == [1, 2, 3]
assert ParseCSVInt.parse("1,2,3,") == [1, 2, 3]
assert ParseCSVInt.parse(" 1 , 2 , 3 ") == [1, 2, 3]
assert ParseCSVInt.parse(" , 1 , 2 , 3 , ") == [1, 2, 3]


def test_unparse_csv():
assert ParseCSV.unparse([]) == ""
assert ParseCSV.unparse([0]) == "0"
assert ParseCSV.unparse([1, 2, 3]) == "1,2,3"
assert ParseCSVInt.unparse([]) == ""
assert ParseCSVInt.unparse([0]) == "0"
assert ParseCSVInt.unparse([1, 2, 3]) == "1,2,3"


def test_parse_csv_roundtrip():
Expand All @@ -216,8 +216,8 @@ def test_parse_csv_roundtrip():
]

for original in test_cases:
unparsed = ParseCSV.unparse(original)
parsed = ParseCSV.parse(unparsed)
unparsed = ParseCSVInt.unparse(original)
parsed = ParseCSVInt.parse(unparsed)
assert parsed == original, f"Roundtrip failed for {original}"


Expand Down
Loading
2A6A
0