diff --git a/feynsum-sml/run.sh b/feynsum-sml/run.sh index 8fa783e..4971d0a 100755 --- a/feynsum-sml/run.sh +++ b/feynsum-sml/run.sh @@ -1,5 +1,5 @@ -./main.mpl @mpl procs 72 set-affinity megablock-threshold 14 cc-threshold-ratio 1.1 collection-threshold-ratio 2.0 max-cc-depth 1 -- -scheduler gfq -input $@ -# cc-threshold-ratio 1.1 collection-threshold-ratio 2.0 - - -# ./all-main.mpl @mpl procs 72 set-affinity megablock-threshold 14 cc-threshold-ratio 1.25 max-cc-depth 1 -- -sim query-bfs -input $@ +./main.mpl @mpl procs 20 set-affinity megablock-threshold 14 cc-threshold-ratio 1.1 collection-threshold-ratio 2.0 max-cc-depth 1 -- -dense-thresh 1.1 --scheduler-disable-fusion -scheduler $1 -input $2 +# -scheduler-max-branching-stride 1 +# --scheduler-disable-fusion +# -dense-thresh 0.75 +# -pull-thresh 0.01 diff --git a/feynsum-sml/src/FullSimBFS.sml b/feynsum-sml/src/FullSimBFS.sml index 393a507..ecaecbd 100644 --- a/feynsum-sml/src/FullSimBFS.sml +++ b/feynsum-sml/src/FullSimBFS.sml @@ -1,38 +1,36 @@ functor FullSimBFS (structure B: BASIS_IDX structure C: COMPLEX - structure SST: SPARSE_STATE_TABLE + structure HS: HYBRID_STATE structure G: GATE - sharing B = SST.B = G.B - sharing C = SST.C = G.C + sharing B = HS.B = G.B + sharing C = HS.C = G.C + val disableFusion: bool + val maxBranchingStride: int + val gateScheduler: string val blockSize: int val maxload: real - val gateScheduler: GateScheduler.t val doMeasureZeros: bool val denseThreshold: real val pullThreshold: real): sig - val run: Circuit.t + val run: DataFlowGraph.t -> {result: (B.t * C.t) option DelayedSeq.t, counts: int Seq.t} end = struct - structure DS = DenseState (structure C = C structure B = B) - structure Expander = ExpandState (structure B = B structure C = C - structure SST = SST - structure DS = DS + structure HS = HS structure G = G val denseThreshold = denseThreshold val blockSize = blockSize val maxload = maxload val pullThreshold = pullThreshold) - val bits = Seq.fromList [ (*"▏",*)"▎", 
"▍", "▌", "▊"] fun fillBar width x = @@ -61,9 +59,9 @@ struct let val ss = case s of - Expander.Sparse sst => SST.unsafeViewContents sst - | Expander.Dense ds => DS.unsafeViewContents ds - | Expander.DenseKnownNonZeroSize (ds, _) => DS.unsafeViewContents ds + HS.Sparse sst => HS.SST.unsafeViewContents sst + | HS.Dense ds => HS.DS.unsafeViewContents ds + | HS.DenseKnownNonZeroSize (ds, _) => HS.DS.unsafeViewContents ds in Util.for (0, DelayedSeq.length ss) (fn i => case DelayedSeq.nth ss i of @@ -76,22 +74,45 @@ struct end) end - - fun run {numQubits, gates} = + structure DGFQ = DynSchedFinishQubitWrapper + (structure B = B + structure C = C + structure HS = HS + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + structure DGI = DynSchedInterference + (structure B = B + structure C = C + structure HS = HS + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + structure DGN = DynSchedNaive + (structure B = B + structure C = C + structure HS = HS + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + val gateSched = + case gateScheduler of + "naive" => DGN.choose + | "gfq" => DGFQ.choose + | "interference" => DGI.choose + | _ => raise Fail ("Unknown scheduler '" ^ gateScheduler ^ "'\n") + + fun run dfg (*{numQubits, gates}*) = let - val gates = Seq.map G.fromGateDefn gates + val gates = Seq.map G.fromGateDefn (#gates dfg) + val numQubits = #numQubits dfg fun gate i = Seq.nth gates i val depth = Seq.length gates + val dgstate = DataFlowGraphUtil.initState dfg - val gateSchedulerPickNextGates = gateScheduler - { numQubits = numQubits - , numGates = depth - , gateTouches = #touches o gate - , gateIsBranching = (fn i => - case #action (gate i) of - G.NonBranching _ => false - | _ => true) - } + val pickNextGate = + let val f = gateSched dfg in + fn (s, g) => f (s, g) + end (* val _ = if numQubits > 63 then raise Fail "whoops, too many qubits" else () *) @@ -134,77 +155,74 @@ 
struct density end - - fun loop numGateApps gatesVisitedSoFar counts prevNonZeroSize state = - if gatesVisitedSoFar >= depth then - let - val (nonZeros, numNonZeros) = - case state of - Expander.Sparse sst => - (SST.unsafeViewContents sst, SST.nonZeroSize sst) - | Expander.Dense ds => - (DS.unsafeViewContents ds, DS.nonZeroSize ds) - | Expander.DenseKnownNonZeroSize (ds, nz) => - (DS.unsafeViewContents ds, nz) - - (* val _ = dumpState numQubits state *) - - val density = - dumpDensity (gatesVisitedSoFar, numNonZeros, NONE, NONE) - in - print "\n"; - (numGateApps, nonZeros, Seq.fromRevList (numNonZeros :: counts)) - end - - else - let - (* val _ = dumpState numQubits state *) - - val theseGates = gateSchedulerPickNextGates () - val _ = - if Seq.length theseGates > 0 then - () - else - raise Fail "FullSimBFS: gate scheduler returned empty sequence" - - (* val _ = print - ("visiting: " ^ Seq.toString Int.toString theseGates ^ "\n") *) - - val theseGates = Seq.map (Seq.nth gates) theseGates - val numGatesVisitedHere = Seq.length theseGates - val ({result, method, numNonZeros, numGateApps = apps}, tm) = - Util.getTime (fn () => - Expander.expand - { gates = theseGates - , numQubits = numQubits - , maxNumStates = maxNumStates - , state = state - , prevNonZeroSize = prevNonZeroSize - }) - - val seconds = Time.toReal tm - val millions = Real.fromInt apps / 1e6 - val throughput = millions / seconds - val throughputStr = Real.fmt (StringCvt.FIX (SOME 2)) throughput - val density = - dumpDensity (gatesVisitedSoFar, numNonZeros, NONE, NONE) - val _ = print - (" hop " ^ leftPad 3 (Int.toString numGatesVisitedHere) ^ " " - ^ rightPad 11 method ^ " " - ^ Real.fmt (StringCvt.FIX (SOME 4)) seconds ^ "s throughput " - ^ throughputStr ^ "\n") - in - loop (numGateApps + apps) (gatesVisitedSoFar + numGatesVisitedHere) - (numNonZeros :: counts) numNonZeros result - end - - - val initialState = Expander.Sparse - (SST.singleton {numQubits = numQubits} (B.zeros, C.defaultReal 1.0)) - - val 
(numGateApps, finalState, counts) = loop 0 0 [] 1 initialState + fun getNumZeros state = + case state of + HS.Sparse sst => HS.SST.zeroSize sst + | HS.Dense ds => 0 (*raise Fail "Can't do dense stuff!"*) + (*DS.unsafeViewContents ds, DS.nonZeroSize ds, TODO exception*) + | HS.DenseKnownNonZeroSize (ds, nz) => 0 (*raise Fail "Can't do dense stuff!"*) + (*DS.unsafeViewContents ds, nz, TODO exception*) + + val initialState = HS.Sparse + (HS.SST.singleton {numQubits = numQubits} (B.zeros, C.defaultReal 1.0)) + + fun runloop () = + DataFlowGraphUtil.scheduleWithOracle' + + (* data flow graph *) + dfg + + (* gate is branching *) + (fn i => G.expectBranching (Seq.nth gates i)) + + (* select gate *) + (fn ((state, numGateApps, counts, gatesVisitedSoFar), gates) => pickNextGate (state, gates)) + + (* disable fusion? *) + disableFusion + + (* if fusion enabled, what's the max # of branching gates to fuse? *) + maxBranchingStride + + (* apply gate fusion seq, updating state *) + (fn ((state, numGateApps, counts, gatesVisitedSoFar), theseGates) => + let val numGatesVisitedHere = Seq.length theseGates + val ({result, method, numNonZeros, numGateApps = apps}, tm) = + Util.getTime (fn () => + Expander.expand + { gates = Seq.map (Seq.nth gates) theseGates + , numQubits = numQubits + , maxNumStates = maxNumStates + , state = state + , prevNonZeroSize = (case counts of h :: t => h | nil => 1) + }) + + val seconds = Time.toReal tm + val millions = Real.fromInt apps / 1e6 + val throughput = millions / seconds + val throughputStr = Real.fmt (StringCvt.FIX (SOME 2)) throughput + val density = + dumpDensity (gatesVisitedSoFar, numNonZeros, SOME (getNumZeros state), NONE) + val _ = print + (" hop " ^ leftPad 3 (Int.toString numGatesVisitedHere) ^ " " + ^ rightPad 11 method ^ " " + ^ Real.fmt (StringCvt.FIX (SOME 4)) seconds ^ "s throughput " + ^ throughputStr ^ "\n") + in + (result, numGateApps + apps, numNonZeros :: counts, gatesVisitedSoFar + numGatesVisitedHere) + end + ) + + (* 
initial state *) + (initialState, 0, [], 0) + + val (finalState, numGateApps, counts, gatesVisited) = runloop () + val nonZeros = case finalState of + HS.Sparse sst => HS.SST.unsafeViewContents sst + | HS.Dense ds => HS.DS.unsafeViewContents ds + | HS.DenseKnownNonZeroSize (ds, nz) => HS.DS.unsafeViewContents ds val _ = print ("gate app count " ^ Int.toString numGateApps ^ "\n") in - {result = finalState, counts = counts} + {result = nonZeros, counts = Seq.fromRevList counts (* counts is consed newest-first during the run; restore chronological order *)} end end diff --git a/feynsum-sml/src/MkMain.sml b/feynsum-sml/src/MkMain.sml index 198fa13..f8667c3 100644 --- a/feynsum-sml/src/MkMain.sml +++ b/feynsum-sml/src/MkMain.sml @@ -1,9 +1,11 @@ functor MkMain (structure C: COMPLEX structure B: BASIS_IDX + val disableFusion: bool + val maxBranchingStride: int val blockSize: int val maxload: real - val gateScheduler: GateScheduler.t + val gateScheduler: string val doMeasureZeros: bool val denseThreshold: real val pullThreshold: real) = @@ -13,13 +15,26 @@ struct structure G = Gate (structure B = B structure C = C) + structure SSTLocked = SparseStateTableLockedSlots (structure B = B structure C = C) + structure SSTLockFree = SparseStateTable (structure B = B structure C = C) + structure DS = DenseState (structure B = B structure C = C) + structure HSLocked = HybridState (structure B = B + structure C = C + structure SST = SSTLocked + structure DS = DS) + structure HSLockFree = HybridState (structure B = B + structure C = C + structure SST = SSTLockFree + structure DS = DS) + structure BFSLocked = FullSimBFS (structure B = B structure C = C - structure SST = - SparseStateTableLockedSlots (structure B = B structure C = C) + structure HS = HSLocked structure G = G + val disableFusion = disableFusion + val maxBranchingStride = maxBranchingStride val blockSize = blockSize val maxload = maxload val gateScheduler = gateScheduler @@ -31,8 +46,10 @@ struct FullSimBFS (structure B = B structure C = C - 
structure SST = SparseStateTable (structure B = B structure C 
= C) + structure HS = HSLockFree structure G = G + val disableFusion = disableFusion + val maxBranchingStride = maxBranchingStride val blockSize = blockSize val maxload = maxload val gateScheduler = gateScheduler @@ -50,23 +67,23 @@ struct fun main (inputName, circuit) = let - val numQubits = Circuit.numQubits circuit + val numQubits = #numQubits circuit val impl = CLA.parseString "impl" "lockfree" val output = CLA.parseString "output" "" val outputDensities = CLA.parseString "output-densities" "" val _ = print ("impl " ^ impl ^ "\n") - val sim = + fun sim () = case impl of - "lockfree" => BFSLockfree.run - | "locked" => BFSLocked.run + "lockfree" => BFSLockfree.run circuit + | "locked" => BFSLocked.run circuit | _ => Util.die ("unknown impl " ^ impl ^ "; valid options are: locked, lockfree\n") - val {result, counts} = Benchmark.run "full-sim-bfs" (fn _ => sim circuit) + val {result, counts} = Benchmark.run "full-sim-bfs" (fn _ => sim ()) val counts = Seq.map IntInf.fromInt counts val maxNumStates = IntInf.pow (2, numQubits) @@ -146,8 +163,8 @@ struct print (String.concatWith "," [ name - , Int.toString (Circuit.numQubits circuit) - , Int.toString (Circuit.numGates circuit) + , Int.toString (#numQubits circuit) + , Int.toString (Seq.length (#gates circuit)) , Real.fmt (StringCvt.FIX (SOME 12)) (Rat.approx maxDensity) , Real.fmt (StringCvt.FIX (SOME 12)) (Rat.approx avgDensity) ] ^ "\n") diff --git a/feynsum-sml/src/common/ApplyUntilFailure.sml b/feynsum-sml/src/common/ApplyUntilFailure.sml deleted file mode 100644 index d12c98f..0000000 --- a/feynsum-sml/src/common/ApplyUntilFailure.sml +++ /dev/null @@ -1,40 +0,0 @@ -structure ApplyUntilFailure: -sig - datatype 'a element_result = Success | Failure of 'a - - val doPrefix: {grain: int, acceleration: real} - -> int * int - -> (int -> 'a element_result) - -> {numApplied: int, failed: 'a Seq.t} -end = -struct - - datatype 'a element_result = Success | Failure of 'a - - fun doPrefix {grain, acceleration} (start, stop) 
doElem = - let - fun loop numApplied (lo, hi) = - if lo >= hi then - {numApplied = numApplied, failed = Seq.empty ()} - else - let - val resultHere = SeqBasis.tabFilter grain (lo, hi) (fn i => - case doElem i of - Success => NONE - | Failure x => SOME x) - - val widthHere = hi - lo - val numApplied' = numApplied + widthHere - - val desiredWidth = Real.ceil (acceleration * Real.fromInt widthHere) - val lo' = hi - val hi' = Int.min (hi + desiredWidth, stop) - in - if Array.length resultHere = 0 then loop numApplied' (lo', hi') - else {numApplied = numApplied', failed = ArraySlice.full resultHere} - end - in - loop 0 (start, stop) - end - -end diff --git a/feynsum-sml/src/common/Circuit.sml b/feynsum-sml/src/common/Circuit.sml index 3020933..d45b3c6 100644 --- a/feynsum-sml/src/common/Circuit.sml +++ b/feynsum-sml/src/common/Circuit.sml @@ -3,7 +3,7 @@ sig type circuit = {numQubits: int, gates: GateDefn.t Seq.t} type t = circuit - val toString: circuit -> string + (*val toString: circuit -> string*) val numGates: circuit -> int val numQubits: circuit -> int @@ -214,102 +214,4 @@ struct {numQubits = numQubits, gates = Seq.map convertGate gates} end - - fun toString {numQubits, gates} = - let - val header = "qreg q[" ^ Int.toString numQubits ^ "];\n" - - fun qi i = - "q[" ^ Int.toString i ^ "]" - - fun doOther {name, params, args} = - let - val pstr = - if Seq.length params = 0 then - "()" - else - "(" - ^ - Seq.iterate (fn (acc, e) => acc ^ ", " ^ Real.toString e) - (Real.toString (Seq.nth params 0)) (Seq.drop params 1) ^ ")" - - val front = name ^ pstr - - val args = - if Seq.length args = 0 then - "" - else - Seq.iterate (fn (acc, i) => acc ^ ", " ^ qi i) - (qi (Seq.nth args 0)) (Seq.drop args 1) - in - front ^ " " ^ args - end - - fun gateToString gate = - case gate of - GateDefn.PauliY i => "y " ^ qi i - | GateDefn.PauliZ i => "z " ^ qi i - | GateDefn.Hadamard i => "h " ^ qi i - | GateDefn.T i => "t " ^ qi i - | GateDefn.Tdg i => "tdg " ^ qi i - | GateDefn.SqrtX i 
=> "sx " ^ qi i - | GateDefn.Sxdg i => "sxdg " ^ qi i - | GateDefn.S i => "s " ^ qi i - | GateDefn.Sdg i => "sdg " ^ qi i - | GateDefn.X i => "x " ^ qi i - | GateDefn.CX {control, target} => "cx " ^ qi control ^ ", " ^ qi target - | GateDefn.CZ {control, target} => "cz " ^ qi control ^ ", " ^ qi target - | GateDefn.CCX {control1, control2, target} => - "ccx " ^ qi control1 ^ ", " ^ qi control2 ^ ", " ^ qi target - | GateDefn.Phase {target, rot} => - "phase(" ^ Real.toString rot ^ ") " ^ qi target - | GateDefn.CPhase {control, target, rot} => - "cphase(" ^ Real.toString rot ^ ") " ^ qi control ^ ", " ^ qi target - | GateDefn.FSim {left, right, theta, phi} => - "fsim(" ^ Real.toString theta ^ ", " ^ Real.toString phi ^ ") " - ^ qi left ^ ", " ^ qi right - | GateDefn.RZ {rot, target} => - "rz(" ^ Real.toString rot ^ ") " ^ qi target - | GateDefn.RY {rot, target} => - "ry(" ^ Real.toString rot ^ ") " ^ qi target - | GateDefn.RX {rot, target} => - "rx(" ^ Real.toString rot ^ ") " ^ qi target - | GateDefn.CSwap {control, target1, target2} => - "cswap " ^ qi control ^ ", " ^ qi target1 ^ ", " ^ qi target2 - | GateDefn.Swap {target1, target2} => - "swap " ^ qi target1 ^ ", " ^ qi target2 - - | GateDefn.U {target, theta, phi, lambda} => - doOther - { name = "u" - , params = Seq.fromList [theta, phi, lambda] - , args = Seq.singleton target - } - - | GateDefn.Other {name, params, args} => - let - val pstr = - if Seq.length params = 0 then - "()" - else - "(" - ^ - Seq.iterate (fn (acc, e) => acc ^ ", " ^ Real.toString e) - (Real.toString (Seq.nth params 0)) (Seq.drop params 1) ^ ")" - - val front = name ^ pstr - - val args = - if Seq.length args = 0 then - "" - else - Seq.iterate (fn (acc, i) => acc ^ ", " ^ qi i) - (qi (Seq.nth args 0)) (Seq.drop args 1) - in - front ^ " " ^ args - end - in - Seq.iterate op^ header (Seq.map (fn g => gateToString g ^ ";\n") gates) - end - end diff --git a/feynsum-sml/src/common/DataFlowGraph.sml b/feynsum-sml/src/common/DataFlowGraph.sml new 
file mode 100644 index 0000000..54a95b1 --- /dev/null +++ b/feynsum-sml/src/common/DataFlowGraph.sml @@ -0,0 +1,181 @@ +structure DataFlowGraph: +sig + + type gate_idx = int + + type data_flow_graph = { + gates: GateDefn.t Seq.t, + preds: gate_idx Seq.t Seq.t, + succs: gate_idx Seq.t Seq.t, + numQubits: int + } + + type t = data_flow_graph + + val fromJSON: JSON.value -> data_flow_graph + val fromJSONString: string -> data_flow_graph + val fromJSONFile: string -> data_flow_graph + val fromQasm: Circuit.t -> data_flow_graph + + val toString: data_flow_graph -> string + (*val fromQasmString: string -> data_flow_graph*) + (*val fromQasmFile: string -> data_flow_graph*) + + (*val mkGateDefn: (string * real list option * int list) -> GateDefn.t*) + +end = +struct + + type gate_idx = int + + type data_flow_graph = { + gates: GateDefn.t Seq.t, + preds: gate_idx Seq.t Seq.t, + succs: gate_idx Seq.t Seq.t, + numQubits: int + } + + type t = data_flow_graph + + fun expect opt err = case opt of + NONE => raise Fail err + | SOME x => x + + fun mkGateDefn ("x", NONE, [q0]) = + GateDefn.X q0 + | mkGateDefn ("y", NONE, [q0]) = + GateDefn.PauliY q0 + | mkGateDefn ("z", NONE, [q0]) = + GateDefn.PauliZ q0 + | mkGateDefn ("h", NONE, [q0]) = + GateDefn.Hadamard q0 + | mkGateDefn ("sx", NONE, [q0]) = + GateDefn.SqrtX q0 + | mkGateDefn ("sxdg", NONE, [q0]) = + GateDefn.Sxdg q0 + | mkGateDefn ("s", NONE, [q0]) = + GateDefn.S q0 + | mkGateDefn ("sdg", NONE, [q0]) = + GateDefn.Sdg q0 (* sdg is S-dagger (inverse of S), not S *) + | mkGateDefn ("t", NONE, [q0]) = + GateDefn.T q0 + | mkGateDefn ("tdg", NONE, [q0]) = + GateDefn.Tdg q0 + | mkGateDefn ("cx", NONE, [control, target]) = + GateDefn.CX {control = control, target = target} + | mkGateDefn ("cz", NONE, [control, target]) = + GateDefn.CZ {control = control, target = target} + | mkGateDefn ("ccx", NONE, [control1, control2, target]) = + GateDefn.CCX {control1 = control1, control2 = control2, target = target} + | mkGateDefn ("phase", SOME [rot], [target]) = + GateDefn.Phase {target = 
target, rot = rot} + | mkGateDefn ("cp", SOME [rot], [control, target]) = + GateDefn.CPhase {control = control, target = target, rot = rot} + | mkGateDefn ("fsim", SOME [theta, phi], [left, right]) = + GateDefn.FSim {left = left, right = right, theta = theta, phi = phi} + | mkGateDefn ("rx", SOME [rot], [target]) = + GateDefn.RX {rot = rot, target = target} + | mkGateDefn ("ry", SOME [rot], [target]) = + GateDefn.RY {rot = rot, target = target} + | mkGateDefn ("rz", SOME [rot], [target]) = + GateDefn.RZ {rot = rot, target = target} + | mkGateDefn ("swap", NONE, [target1, target2]) = + GateDefn.Swap {target1 = target1, target2 = target2} + | mkGateDefn ("cswap", NONE, [control, target1, target2]) = + GateDefn.CSwap {control = control, target1 = target1, target2 = target2} + | mkGateDefn ("u", SOME [theta, phi, lambda], [target]) = + GateDefn.U {target = target, theta = theta, phi = phi, lambda = lambda} + | mkGateDefn ("u2", SOME [phi, lambda], [target]) = + GateDefn.U {target = target, theta = Math.pi/2.0, phi = phi, lambda = lambda} + | mkGateDefn ("u1", SOME [lambda], [target]) = + GateDefn.U {target = target, theta = 0.0, phi = 0.0, lambda = lambda} + | mkGateDefn (name, params, qargs) = + raise Fail ("Unknown gate-params-qargs combination with name " ^ name) + (*fun mkGate (g) = G.fromGateDefn (mkGateDefn g)*) + + fun arrayToSeq a = Seq.tabulate (fn i => Array.sub (a, i)) (Array.length a) + + fun getSuccsPredsFromJSON (edges, N) = + let val preds = Array.array (N, nil); + val succs = Array.array (N, nil); + fun consAt (a, i, x) = Array.update (a, i, (x :: Array.sub (a, i))) + fun go nil = () + | go (JSON.ARRAY [JSON.INT fm, JSON.INT to] :: edges) = + let val fm64 = IntInf.toInt fm + val to64 = IntInf.toInt to + val _ = consAt (preds, to64, fm64) + val _ = consAt (succs, fm64, to64) + in + go edges + end + | go (_ :: edges) = raise Fail "Malformed edge in JSON" + val () = go edges; + val toSeq = Seq.map (Seq.rev o Seq.fromList) o arrayToSeq + in + (toSeq preds, 
toSeq succs) + end + + fun fromJSON (data) = + let fun to_gate g = + let val name = JSONUtil.asString (expect (JSONUtil.findField g "name") "Expected field 'name' in JSON"); + val params = Option.map (JSONUtil.arrayMap JSONUtil.asNumber) (JSONUtil.findField g "params"); + val qargs = JSONUtil.arrayMap JSONUtil.asInt (expect (JSONUtil.findField g "qargs") "Expected field 'qargs' in JSON"); + in + mkGateDefn (name, params, qargs) + end + val numqs = case JSONUtil.findField data "qubits" of + SOME (JSON.INT qs) => IntInf.toInt qs + | _ => raise Fail "Expected integer field 'qubits' in JSON" + val gates = case JSONUtil.findField data "nodes" of + SOME (JSON.ARRAY ns) => Seq.fromList (List.map to_gate ns) + | _ => raise Fail "Expected array field 'nodes' in JSON" + val edges = case JSONUtil.findField data "edges" of + SOME (JSON.ARRAY es) => es + | _ => raise Fail "Expected array field 'edges' in JSON" + val (preds, succs) = getSuccsPredsFromJSON (edges, Seq.length gates) + in + { gates = gates, + preds = preds, + succs = succs, + numQubits = numqs } + (*Seq.zipWith (fn (g, (d, i)) => {gate = g, deps = d, indegree = i}) (gates, Seq.zip (deps, indegree))*) + end + fun fromJSONString (str) = fromJSON (JSONParser.parse (JSONParser.openString str)) + fun fromJSONFile (file) = fromJSON (JSONParser.parseFile file) + + (* Build the data-flow graph of a gate list: gate i gets an edge from the most recent earlier gate on each qubit it touches. *) + fun fromQasm {numQubits, gates} = + let val numGates = Seq.length gates + val qubitLastGate = Array.array (numQubits, ~1) + val preds = Array.array (numGates, nil) + val succs = Array.array (numGates, nil) + fun fillPreds gidx = + if gidx >= numGates then + () + else + let val gate = Seq.nth gates gidx + val args = GateDefn.getGateArgs gate + val lasts = List.filter (fn i => i >= 0) (List.map (fn qidx => Array.sub (qubitLastGate, qidx)) args) + val _ = Array.update (preds, gidx, lasts) + val _ = List.map (fn gidx' => Array.update (succs, gidx', gidx :: Array.sub (succs, gidx'))) lasts + val _ = List.map (fn qidx => 
Array.update (qubitLastGate, qidx, gidx)) args + in + fillPreds (gidx + 1) + end + val _ = fillPreds 0 + val predsSeq = Seq.map Seq.fromList (arrayToSeq preds) + val succsSeq = Seq.map (Seq.rev o Seq.fromList) (arrayToSeq succs) + in + { gates = gates, + preds = predsSeq, + succs = succsSeq, + numQubits = numQubits } + end + +fun toString {gates, preds, succs, numQubits} = + let val header = "qreg q[" ^ Int.toString numQubits ^ "];\n" + fun qi i = "q[" ^ Int.toString i ^ "]" + in + Seq.iterate op^ header (Seq.map (fn g => GateDefn.toString g qi ^ ";\n") gates) + end +end diff --git a/feynsum-sml/src/common/DataFlowGraphUtil.sml b/feynsum-sml/src/common/DataFlowGraphUtil.sml new file mode 100644 index 0000000..ffdb5da --- /dev/null +++ b/feynsum-sml/src/common/DataFlowGraphUtil.sml @@ -0,0 +1,179 @@ +structure DataFlowGraphUtil :> +sig + + type gate_idx = int + + (* Traversal Automaton State *) + type state = { visited: bool array, indegree: int array } + + type data_flow_graph = DataFlowGraph.t + + val visit: data_flow_graph -> gate_idx -> state -> unit + val frontier: state -> gate_idx Seq.t + val initState: data_flow_graph -> state + + (* Switches edge directions *) + val transpose: data_flow_graph -> data_flow_graph + + val scheduleWithOracle: data_flow_graph -> (gate_idx -> bool) -> (gate_idx Seq.t -> gate_idx) -> bool -> int -> gate_idx Seq.t Seq.t + + val scheduleWithOracle': data_flow_graph -> (gate_idx -> bool) -> ('state * gate_idx Seq.t -> gate_idx) -> bool -> int -> ('state * gate_idx Seq.t -> 'state) -> 'state -> 'state + + val scheduleCost: gate_idx Seq.t Seq.t -> (gate_idx -> bool) -> real + val chooseSchedule: gate_idx Seq.t Seq.t Seq.t -> (gate_idx -> bool) -> gate_idx Seq.t Seq.t + + val gateIsBranching: data_flow_graph -> (gate_idx -> bool) +end = +struct + + type gate_idx = int + + type data_flow_graph = DataFlowGraph.t + + fun transpose ({gates, preds, succs, numQubits}: data_flow_graph) = + { gates = gates, + preds = succs, + succs = preds, + 
numQubits = numQubits } + + (*fun transpose ({gates = gs, deps = ds, indegree = is, numQubits = qs}: data_flow_graph) = + let val N = Seq.length gs + val ds2 = Array.array (N, nil) + fun apply i = Seq.map (fn j => Array.update (ds2, j, i :: Array.sub (ds2, j))) (Seq.nth ds i) + val _ = Seq.tabulate apply N + in + {gates = gs, + deps = Seq.tabulate (fn i => Seq.rev (Seq.fromList (Array.sub (ds2, i)))) N, + indegree = Seq.map Seq.length ds, + numQubits = qs} + end*) + + type state = { visited: bool array, indegree: int array } + + fun visit {succs = succs, ...} i {visited = vis, indegree = deg} = + ( + (* Set visited[i] = true *) + Array.update (vis, i, true); + (* Decrement indegree of each i dependency *) + Seq.map (fn j => Array.update (deg, j, Array.sub (deg, j) - 1)) (Seq.nth succs i); + () + ) + + fun frontier {visited = vis, indegree = deg} = + let val N = Array.length vis + fun iter i acc = + if i < 0 then + acc + else + iter (i - 1) (if (not (Array.sub (vis, i))) andalso + Array.sub (deg, i) = 0 + then i :: acc else acc) + in + Seq.fromList (iter (N - 1) nil) + end + + fun initState (graph: data_flow_graph) = + let val N = Seq.length (#gates graph) + val vis = Array.array (N, false) + val deg = Array.tabulate (N, Seq.length o Seq.nth (#preds graph)) + in + { visited = vis, indegree = deg } + end + + fun scheduleWithOracle' (graph: data_flow_graph) (branching: gate_idx -> bool) (choose: 'state * gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) (apply: 'state * gate_idx Seq.t -> 'state) state = + let val dgst = initState graph + fun findNonBranching (i: int) (xs: gate_idx Seq.t) = + if i = Seq.length xs then + NONE + else if branching (Seq.nth xs i) then + findNonBranching (i + 1) xs + else + SOME (Seq.nth xs i) + fun returnSeq thisStep acc = Seq.map Seq.fromList (Seq.rev (Seq.fromList (if List.null thisStep then acc else List.rev thisStep :: acc))) + fun loadNonBranching (acc: gate_idx list) = + case findNonBranching 0 (frontier 
dgst) of + NONE => acc + | SOME i => (visit graph i dgst; loadNonBranching (i :: acc)) + fun loadNext numBranchingSoFar thisStep state = + let val ftr = frontier dgst in + if Seq.length ftr = 0 then + (if List.null thisStep then state else apply (state, Seq.rev (Seq.fromList thisStep))) (* flush gates already visited (fused) but not yet applied *) + else + (let val next = choose (state, ftr) in + visit graph next dgst; + if numBranchingSoFar + 1 >= maxBranchingStride then + loadNext 0 nil (apply (state, Seq.rev (Seq.fromList (loadNonBranching (next :: thisStep))))) + else + loadNext (numBranchingSoFar + 1) (loadNonBranching (next :: thisStep)) state + end) + end + fun loadNextNoFusion state = + let val ftr = frontier dgst in + if Seq.length ftr = 0 then + state + else + (let val next = choose (state, ftr) in + visit graph next dgst; + loadNextNoFusion (apply (state, Seq.singleton next)) + end) + end + in + if disableFusion then + loadNextNoFusion state + else + loadNext 0 (loadNonBranching nil) state + end + + fun scheduleWithOracle (graph: data_flow_graph) (branching: gate_idx -> bool) (choose: gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) = Seq.rev (Seq.fromList (scheduleWithOracle' graph branching (fn (_, x) => choose x) disableFusion maxBranchingStride (fn (gs, g) => g :: gs) nil)) + + (*fun scheduleCost2 (order: gate_idx Seq.t Seq.t) (branching: gate_idx -> bool) = + let val gates = Seq.flatten order + val N = Seq.length gates + fun iter i cost branchedQubits = + if i = N then + cost + else + iter (i + 1) (1.0 + cost + (if branching (Seq.nth gates i) then cost else 0.0)) + in + iter 0 0.0 (Vector.tabulate (N, fn _ => false)) + end*) + + + fun scheduleCost (order: gate_idx Seq.t Seq.t) (branching: gate_idx -> bool) = + let val gates = Seq.flatten order + val N = Seq.length gates + fun iter i cost = + if i = N then + cost + else + iter (i + 1) (1.0 + (if branching (Seq.nth gates i) then cost * 1.67 else cost)) + in + iter 0 0.0 + end + + fun chooseSchedule (orders: gate_idx Seq.t Seq.t Seq.t) (branching: gate_idx -> bool) = + let fun iter i best_i best_cost = + if i = 
Seq.length orders then + Seq.nth orders best_i + else + let val cost = scheduleCost (Seq.nth orders i) branching in + if cost < best_cost then + (print ("Reduced cost from " ^ Real.toString best_cost ^ " to " ^ Real.toString cost ^ "\n"); iter (i + 1) i cost) + else + (print ("Maintained cost " ^ Real.toString best_cost ^ " over " ^ Real.toString cost ^ "\n"); iter (i + 1) best_i best_cost) + end + in + iter 1 0 (scheduleCost (Seq.nth orders 0) branching) + end + + (* B and C don't affect gate_branching, so pick arbitrarily *) + structure Gate_branching = Gate (structure B = BasisIdxUnlimited + structure C = Complex64) + + + fun gateIsBranching ({ gates = gates, ...} : data_flow_graph) = + let val branchSeq = Seq.map (fn g => Gate_branching.expectBranching (Gate_branching.fromGateDefn g)) gates in + fn i => Seq.nth branchSeq i + end + +end diff --git a/feynsum-sml/src/common/DynGateScheduler.sml b/feynsum-sml/src/common/DynGateScheduler.sml new file mode 100644 index 0000000..d63f4cb --- /dev/null +++ b/feynsum-sml/src/common/DynGateScheduler.sml @@ -0,0 +1,199 @@ +signature DYN_GATE_SCHEDULER = +sig + structure B: BASIS_IDX + structure C: COMPLEX + structure HS: HYBRID_STATE + sharing B = HS.B + sharing C = HS.C + + type gate_idx = int + + type t = DataFlowGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) + + val choose: t +end + +functor DynSchedFinishQubitWrapper + (structure B: BASIS_IDX + structure C: COMPLEX + structure HS: HYBRID_STATE + sharing B = HS.B + sharing C = HS.C + val maxBranchingStride: int + val disableFusion: bool + ): DYN_GATE_SCHEDULER = +struct + structure B = B + structure C = C + structure HS = HS + + type gate_idx = int + + type t = DataFlowGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) + + structure FQS = FinishQubitScheduler + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + structure G = Gate + (structure B = B + structure C = C) + + fun choose (depgraph: DataFlowGraph.t) = + let val f = FQS.scheduler5 
depgraph in + fn (_, gates) => f gates + end + +end + +functor DynSchedNaive + (structure B: BASIS_IDX + structure C: COMPLEX + structure HS: HYBRID_STATE + sharing B = HS.B + sharing C = HS.C + val maxBranchingStride: int + val disableFusion: bool + ): DYN_GATE_SCHEDULER = +struct + structure B = B + structure C = C + structure HS = HS + + type gate_idx = int + + type t = DataFlowGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) + + fun choose (depgraph: DataFlowGraph.t) = fn (_, gates) => Seq.nth gates 0 + +end + + +functor DynSchedInterference + (structure B: BASIS_IDX + structure C: COMPLEX + structure HS: HYBRID_STATE + sharing B = HS.B + sharing C = HS.C + val maxBranchingStride: int + val disableFusion: bool + ): DYN_GATE_SCHEDULER = +struct + structure B = B + structure C = C + structure HS = HS + + type gate_idx = int + + type t = DataFlowGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) + + structure FQS = FinishQubitScheduler + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + structure G = Gate + (structure B = B + structure C = C) + + datatype Branched = Uninitialized | Zero | One | Superposition + + fun joinBranches' (br, b01) = + case (br, b01) of + (Uninitialized, false) => Zero + | (Uninitialized, true) => One + | (Zero, false) => Zero + | (Zero, true) => Superposition + | (One, false) => Superposition + | (One, true) => One + | (Superposition, _) => Superposition + + fun joinBranches (br, br') = + case (br, br') of + (Uninitialized, b) => b + | (a, Uninitialized) => a + | (Zero, Zero) => Zero + | (One, One) => One + | (_, _) => Superposition + + fun calculateBranchedQubits (numQubits, sst) = + let val nonZeros = DelayedSeq.mapOption (fn x => x) (HS.SST.unsafeViewContents sst) + val branchedQubits = Seq.tabulate (fn qi => DelayedSeq.reduce joinBranches Uninitialized (DelayedSeq.map (fn (b, c) => if B.get b qi then One else Zero) nonZeros)) numQubits + fun isbranched b = + case b of + Uninitialized => raise Fail 
"Uninitialized in isbranched! This shouldn't happen" + | Zero => false + | One => false + | Superposition => true + in + Seq.map isbranched branchedQubits + end + + fun choose (depgraph: DataFlowGraph.t) = + let val gates = Seq.map G.fromGateDefn (#gates depgraph) + val branchSeq = Seq.map G.expectBranching gates + fun branching i = Seq.nth branchSeq i + val numQubits = #numQubits depgraph + in + fn (hs, gidxs) => case hs of + HS.Dense d => Seq.nth gidxs 0 + | HS.DenseKnownNonZeroSize d => Seq.nth gidxs 0 + | HS.Sparse sst => + (* sparse: score frontier gates; branching gates score higher on already-superposed qubits, non-branching gates on non-superposed ones; ties keep the earliest *) + let val branchedQubits = calculateBranchedQubits (numQubits, sst) + fun getBranching i = + Seq.reduce (fn ((br, nbr), (br', nbr')) => (br + br', nbr + nbr')) (0, 0) + (Seq.map (fn qi => if Seq.nth branchedQubits qi then + (1, 0) else (0, 1)) + (#touches (Seq.nth gates i))) + fun pick i (best_g, best_diff) = + if i >= Seq.length gidxs then + best_g + else + let val g = Seq.nth gidxs i + val (br, nbr) = getBranching g + val diff = if branching g then br - nbr else nbr - br + in + pick (i + 1) (if diff > best_diff then + (g, diff) else (best_g, best_diff)) + end + val (br0, nbr0) = getBranching (Seq.nth gidxs 0) + in + pick 1 (Seq.nth gidxs 0, if branching (Seq.nth gidxs 0) then br0 - nbr0 else nbr0 - br0) + end + end + +end + + +(*functor DataFlowGraphDynScheduler + (structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + structure G: GATE + sharing B = SST.B = DS.B = G.B + sharing C = SST.C = DS.C = G.C + val blockSize: int + val maxload: real + val denseThreshold: real + val pullThreshold: real): DYN_GATE_SCHEDULER = +struct + type gate_idx = int + structure Expander = + ExpandState + (structure B = B + structure C = C + structure SST = SST + structure DS = DS + structure G = G + val denseThreshold = denseThreshold + val blockSize = blockSize + val maxload = maxload + val pullThreshold = pullThreshold) + + ( * From a frontier, select which gate to 
apply next * ) + ( * args visit gates, update frontier break fusion initial frontier gate batches * ) + ( * type t = args -> (gate_idx -> gate_idx Seq.t) -> (unit -> unit) -> gate_idx Seq.t -> gate_idx Seq.t Seq.t* ) + type t = DataFlowGraph.t -> (Expander.state * gate_idx Seq.t -> gate_idx) +end +*) diff --git a/feynsum-sml/src/common/GateDefn.sml b/feynsum-sml/src/common/GateDefn.sml index a2f7928..4387268 100644 --- a/feynsum-sml/src/common/GateDefn.sml +++ b/feynsum-sml/src/common/GateDefn.sml @@ -41,4 +41,80 @@ struct type gate = t + fun getGateArgs (g: gate) = case g of + PauliY i => [i] + | PauliZ i => [i] + | Hadamard i => [i] + | SqrtX i => [i] + | Sxdg i => [i] + | S i => [i] + | Sdg i => [i] + | X i => [i] + | T i => [i] + | Tdg i => [i] + | CX {control = i, target = j} => [i, j] + | CZ {control = i, target = j} => [i, j] + | CCX {control1 = i, control2 = j, target = k} => [i, j, k] + | Phase {target = i, ...} => [i] + | CPhase {control = i, target = j, ...} => [i, j] + | FSim {left = i, right = j, ...} => [i, j] + | RZ {target = i, ...} => [i] + | RY {target = i, ...} => [i] + | RX {target = i, ...} => [i] + | Swap {target1 = i, target2 = j} => [i, j] + | CSwap {control = i, target1 = j, target2 = k} => [i, j, k] + | U {target = i, ...} => [i] + | Other {args = args, ...} => Seq.toList args + + fun toString (g: gate) (qi: qubit_idx -> string) = case g of + PauliY i => "y " ^ qi i + | PauliZ i => "z " ^ qi i + | Hadamard i => "h " ^ qi i + | T i => "t " ^ qi i + | Tdg i => "tdg " ^ qi i + | SqrtX i => "sx " ^ qi i + | Sxdg i => "sxdg " ^ qi i + | S i => "s " ^ qi i + | Sdg i => "sdg " ^ qi i + | X i => "x " ^ qi i + | CX {control, target} => "cx " ^ qi control ^ ", " ^ qi target + | CZ {control, target} => "cz " ^ qi control ^ ", " ^ qi target + | CCX {control1, control2, target} => + "ccx " ^ qi control1 ^ ", " ^ qi control2 ^ ", " ^ qi target + | Phase {target, rot} => + "phase(" ^ Real.toString rot ^ ") " ^ qi target + | CPhase {control, target, rot} 
=> + "cphase(" ^ Real.toString rot ^ ") " ^ qi control ^ ", " ^ qi target + | FSim {left, right, theta, phi} => + "fsim(" ^ Real.toString theta ^ ", " ^ Real.toString phi ^ ") " + ^ qi left ^ ", " ^ qi right + | RZ {rot, target} => + "rz(" ^ Real.toString rot ^ ") " ^ qi target + | RY {rot, target} => + "ry(" ^ Real.toString rot ^ ") " ^ qi target + | RX {rot, target} => + "rx(" ^ Real.toString rot ^ ") " ^ qi target + | CSwap {control, target1, target2} => + "cswap " ^ qi control ^ ", " ^ qi target1 ^ ", " ^ qi target2 + | Swap {target1, target2} => + "swap " ^ qi target1 ^ ", " ^ qi target2 + | U {target, theta, phi, lambda} => + "u(" ^ Real.toString theta ^ ", " ^ Real.toString phi ^ ", " ^ Real.toString lambda ^ ") " ^ qi target + | Other {name, params, args} => + let val pstr = + if Seq.length params = 0 then + "()" + else + "(" ^ Seq.iterate (fn (acc, e) => acc ^ ", " ^ Real.toString e) + (Real.toString (Seq.nth params 0)) (Seq.drop params 1) ^ ")" + val front = name ^ pstr + val args = + if Seq.length args = 0 then + "" + else + Seq.iterate (fn (acc, i) => acc ^ ", " ^ qi i) + (qi (Seq.nth args 0)) (Seq.drop args 1) + in + front ^ " " ^ args + end end diff --git a/feynsum-sml/src/common/GateScheduler.sml b/feynsum-sml/src/common/GateScheduler.sml index 9009653..b07bc03 100644 --- a/feynsum-sml/src/common/GateScheduler.sml +++ b/feynsum-sml/src/common/GateScheduler.sml @@ -1,16 +1,10 @@ structure GateScheduler = struct - type qubit_idx = int type gate_idx = int - type args = - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - } - - type t = args -> (unit -> gate_idx Seq.t) - + (* From a frontier, select which gate to apply next *) + (* args visit gates, update frontier break fusion initial frontier gate batches *) + (*type t = args -> (gate_idx -> gate_idx Seq.t) -> (unit -> unit) -> gate_idx Seq.t -> gate_idx Seq.t Seq.t*) + type t = DataFlowGraph.t -> (gate_idx Seq.t -> gate_idx) end diff 
--git a/feynsum-sml/src/common/GateSchedulerGreedyBranching.sml b/feynsum-sml/src/common/GateSchedulerGreedyBranching.sml deleted file mode 100644 index c23c303..0000000 --- a/feynsum-sml/src/common/GateSchedulerGreedyBranching.sml +++ /dev/null @@ -1,157 +0,0 @@ -structure GateSchedulerGreedyBranching: -sig - val scheduler: GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - - datatype sched = - S of - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - (* each qubit keeps track of which gate is next *) - , frontier: gate_idx array - } - - - type t = sched - - - fun contains x xs = - Util.exists (0, Seq.length xs) (fn i => Seq.nth xs i = x) - - - fun nextTouch (xxx as {numGates, gateTouches}) qubit gidx = - if gidx >= numGates then numGates - else if contains qubit (gateTouches gidx) then gidx - else nextTouch xxx qubit (gidx + 1) - - - fun new {numQubits, numGates, gateTouches, gateIsBranching} = - S { numQubits = numQubits - , numGates = numGates - , gateTouches = gateTouches - , gateIsBranching = gateIsBranching - , frontier = SeqBasis.tabulate 100 (0, numQubits) (fn i => - nextTouch {numGates = numGates, gateTouches = gateTouches} i 0) - } - - - (* It's safe to visit a gate G if, for all qubits the gate touches, the - * qubit's next gate is G. 
- *) - fun okayToVisit - (S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) gidx = - if gidx >= numGates then - false - else - let - val touches = gateTouches gidx - in - Util.all (0, Seq.length touches) (fn i => - Array.sub (frontier, Seq.nth touches i) = gidx) - end - - - fun tryVisit - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) - gidx = - if not (okayToVisit sched gidx) then - false - else - let - (* val _ = print - ("GateScheduler.tryVisit " ^ Int.toString gidx ^ " numGates " - ^ Int.toString numGates ^ " frontier " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val touches = gateTouches gidx - in - ( Util.for (0, Seq.length touches) (fn i => - let - val qi = Seq.nth touches i - val next = - nextTouch {numGates = numGates, gateTouches = gateTouches} qi - (gidx + 1) - in - Array.update (frontier, qi, next) - end) - - ; true - ) - end - - - fun visitBranching - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - val possibles = - Seq.filter - (fn gi => - gi < numGates andalso gateIsBranching gi - andalso okayToVisit sched gi) (ArraySlice.full frontier) - in - if Seq.length possibles = 0 then - NONE - else - let - val gidx = Seq.nth possibles 0 - in - if tryVisit sched gidx then () - else raise Fail "GateSchedulerGreedyBranching.visitBranching: error"; - - SOME gidx - end - end - - - fun visitNonBranching - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - (* val _ = print - ("visitNonBranching frontier = " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val possibles = - Seq.filter - (fn gi => - gi < numGates andalso not (gateIsBranching gi) - andalso okayToVisit sched gi) (ArraySlice.full frontier) - in - if Seq.length possibles = 0 then - NONE - else - let - val gidx = Seq.nth possibles 0 - in - (* print - ("visitNonBranching trying to visit " ^ Int.toString gidx ^ "\n"); *) - - if tryVisit sched gidx 
then - () - else - raise Fail "GateSchedulerGreedyBranching.visitNonBranching: error"; - - SOME gidx - end - end - - - fun pickNext sched = - case visitBranching sched of - SOME gidx => Seq.singleton gidx - | NONE => - case visitNonBranching sched of - SOME gidx => Seq.singleton gidx - | NONE => Seq.empty () - - - fun scheduler args = - let val sched = new args - in fn () => pickNext sched - end - -end diff --git a/feynsum-sml/src/common/GateSchedulerGreedyFinishQubit.sml b/feynsum-sml/src/common/GateSchedulerGreedyFinishQubit.sml deleted file mode 100644 index 7f1b579..0000000 --- a/feynsum-sml/src/common/GateSchedulerGreedyFinishQubit.sml +++ /dev/null @@ -1,178 +0,0 @@ -functor GateSchedulerGreedyFinishQubit - (val maxBranchingStride: int val disableFusion: bool): -sig - val scheduler: GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - - datatype sched = - S of - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - (* each qubit keeps track of which gate is next *) - , frontier: gate_idx array - } - - - type t = sched - - - fun contains x xs = - Util.exists (0, Seq.length xs) (fn i => Seq.nth xs i = x) - - - fun nextTouch (xxx as {numGates, gateTouches}) qubit gidx = - if gidx >= numGates then numGates - else if contains qubit (gateTouches gidx) then gidx - else nextTouch xxx qubit (gidx + 1) - - - fun new {numQubits, numGates, gateTouches, gateIsBranching} = - S { numQubits = numQubits - , numGates = numGates - , gateTouches = gateTouches - , gateIsBranching = gateIsBranching - , frontier = SeqBasis.tabulate 100 (0, numQubits) (fn i => - nextTouch {numGates = numGates, gateTouches = gateTouches} i 0) - } - - - (* It's safe to visit a gate G if, for all qubits the gate touches, the - * qubit's next gate is G. 
- *) - fun okayToVisit - (S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) gidx = - if gidx >= numGates then - false - else - let - val touches = gateTouches gidx - in - Util.all (0, Seq.length touches) (fn i => - Array.sub (frontier, Seq.nth touches i) = gidx) - end - - - fun tryVisit - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) - gidx = - if not (okayToVisit sched gidx) then - false - else - let - (* val _ = print - ("GateScheduler.tryVisit " ^ Int.toString gidx ^ " numGates " - ^ Int.toString numGates ^ " frontier " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val touches = gateTouches gidx - in - ( Util.for (0, Seq.length touches) (fn i => - let - val qi = Seq.nth touches i - val next = - nextTouch {numGates = numGates, gateTouches = gateTouches} qi - (gidx + 1) - in - Array.update (frontier, qi, next) - end) - - ; true - ) - end - - - fun peekNext - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - fun makeProgressOnQubit qi = - let - val desiredGate = Array.sub (frontier, qi) - in - if okayToVisit sched desiredGate then - desiredGate - else - (* Find which qubit is not ready to do this gate yet, and make - * progress on that qubit (which might recursively need to - * make progress on a different qubit, etc.) 
- *) - let - val touches = gateTouches desiredGate - val dependency = - FindFirst.findFirstSerial (0, Seq.length touches) (fn i => - let val qj = Seq.nth touches i - in Array.sub (frontier, qj) < desiredGate - end) - in - case dependency of - SOME i => makeProgressOnQubit (Seq.nth touches i) - | NONE => - raise Fail - "GateSchedulerGreedyFinishQubit.peekNext.makeProgressOnQubit: error" - end - end - - val unfinishedQubit = FindFirst.findFirstSerial (0, numQubits) (fn qi => - Array.sub (frontier, qi) < numGates) - in - case unfinishedQubit of - NONE => NONE - | SOME qi => SOME (makeProgressOnQubit qi) - end - - - fun pickNextNoFusion sched = - case peekNext sched of - NONE => Seq.empty () - | SOME gidx => - ( if tryVisit sched gidx then - () - else - raise Fail - "GateSchedulerGreedyFinishQubit.pickNextNoFusion: visit failed (should be impossible)" - ; Seq.singleton gidx - ) - - - fun pickNext (sched as S {gateIsBranching, ...}) = - if disableFusion then - pickNextNoFusion sched - else - let - fun loop acc numBranchingSoFar = - if numBranchingSoFar >= maxBranchingStride then - acc - else - case peekNext sched of - NONE => acc - | SOME gidx => - let - val numBranchingSoFar' = - numBranchingSoFar + (if gateIsBranching gidx then 1 else 0) - in - if tryVisit sched gidx then - () - else - raise Fail - "GateSchedulerGreedyFinishQubit.pickNext.loop: visit failed (should be impossible)"; - - loop (gidx :: acc) numBranchingSoFar' - end - - val acc = loop [] 0 - in - Seq.fromRevList acc - end - - - fun scheduler args = - let val sched = new args - in fn () => pickNext sched - end - -end diff --git a/feynsum-sml/src/common/GateSchedulerGreedyNonBranching.sml b/feynsum-sml/src/common/GateSchedulerGreedyNonBranching.sml deleted file mode 100644 index 6b2e80a..0000000 --- a/feynsum-sml/src/common/GateSchedulerGreedyNonBranching.sml +++ /dev/null @@ -1,228 +0,0 @@ -functor GateSchedulerGreedyNonBranching - (val maxBranchingStride: int val disableFusion: bool): -sig - type sched - 
type t = sched - - type qubit_idx = int - type gate_idx = int - - val new: GateScheduler.args -> sched - - val tryVisit: sched -> gate_idx -> bool - val visitMaximalNonBranchingRun: sched -> gate_idx Seq.t - val visitBranching: sched -> gate_idx option - val visitNonBranching: sched -> gate_idx option - val pickNext: sched -> gate_idx Seq.t - - val scheduler: GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - - datatype sched = - S of - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - (* each qubit keeps track of which gate is next *) - , frontier: gate_idx array - } - - - type t = sched - - - fun contains x xs = - Util.exists (0, Seq.length xs) (fn i => Seq.nth xs i = x) - - - fun nextTouch (xxx as {numGates, gateTouches}) qubit gidx = - if gidx >= numGates then numGates - else if contains qubit (gateTouches gidx) then gidx - else nextTouch xxx qubit (gidx + 1) - - - fun new {numQubits, numGates, gateTouches, gateIsBranching} = - S { numQubits = numQubits - , numGates = numGates - , gateTouches = gateTouches - , gateIsBranching = gateIsBranching - , frontier = SeqBasis.tabulate 100 (0, numQubits) (fn i => - nextTouch {numGates = numGates, gateTouches = gateTouches} i 0) - } - - - fun okayToVisit - (S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) gidx = - if gidx >= numGates then - false - else - let - val touches = gateTouches gidx - in - Util.all (0, Seq.length touches) (fn i => - Array.sub (frontier, Seq.nth touches i) = gidx) - end - - - (* It's safe to visit a gate G if, for all qubits the gate touches, the - * qubit's next gate is G. 
- *) - fun tryVisit - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) - gidx = - if not (okayToVisit sched gidx) then - false - else - let - (* val _ = print - ("GateScheduler.tryVisit " ^ Int.toString gidx ^ " numGates " - ^ Int.toString numGates ^ " frontier " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val touches = gateTouches gidx - in - ( Util.for (0, Seq.length touches) (fn i => - let - val qi = Seq.nth touches i - val next = - nextTouch {numGates = numGates, gateTouches = gateTouches} qi - (gidx + 1) - in - Array.update (frontier, qi, next) - end) - - ; true - ) - end - - - fun visitMaximalNonBranchingRun - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - (* visit as many non-branching gates on qubit qi as possible *) - fun loopQubit acc qi = - let - val nextg = Array.sub (frontier, qi) - in - if - nextg >= numGates orelse gateIsBranching nextg - orelse not (tryVisit sched nextg) - then acc - else loopQubit (nextg :: acc) qi - end - - fun loop acc = - let - val selection = Util.loop (0, numQubits) [] (fn (acc, qi) => - loopQubit acc qi) - in - case selection of - [] => acc - | _ => loop (selection @ acc) - end - in - Seq.fromRevList (loop []) - end - - - fun visitBranching - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - val possibles = - Seq.filter - (fn gi => - gi < numGates andalso gateIsBranching gi - andalso okayToVisit sched gi) (ArraySlice.full frontier) - in - if Seq.length possibles = 0 then - NONE - else - let - val gidx = Seq.nth possibles 0 - in - if tryVisit sched gidx then - () - else - raise Fail "GateSchedulerGreedyNonBranching.visitBranching: error"; - - SOME gidx - end - end - - - fun visitNonBranching - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - (* val _ = print - ("visitNonBranching frontier = " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val 
possibles = - Seq.filter - (fn gi => - gi < numGates andalso not (gateIsBranching gi) - andalso okayToVisit sched gi) (ArraySlice.full frontier) - in - if Seq.length possibles = 0 then - NONE - else - let - val gidx = Seq.nth possibles 0 - in - (* print - ("visitNonBranching trying to visit " ^ Int.toString gidx ^ "\n"); *) - - if tryVisit sched gidx then - () - else - raise Fail - "GateSchedulerGreedyNonBranching.visitNonBranching: error"; - - SOME gidx - end - end - - - fun pickNextNoFusion sched = - case visitNonBranching sched of - SOME gidx => Seq.singleton gidx - | NONE => - case visitBranching sched of - SOME gidx => Seq.singleton gidx - | NONE => Seq.empty () - - - fun pickNext sched = - if disableFusion then - pickNextNoFusion sched - else - let - fun loop acc numBranchingSoFar = - if numBranchingSoFar >= maxBranchingStride then - acc - else - let - val nb = visitMaximalNonBranchingRun sched - in - case visitBranching sched of - NONE => nb :: acc - | SOME gidx => - loop (Seq.singleton gidx :: nb :: acc) (numBranchingSoFar + 1) - end - - val acc = loop [] 0 - in - Seq.flatten (Seq.fromRevList acc) - end - - - fun scheduler args = - let val sched = new args - in fn () => pickNext sched - end - -end diff --git a/feynsum-sml/src/common/GateSchedulerNaive.sml b/feynsum-sml/src/common/GateSchedulerNaive.sml deleted file mode 100644 index 917c476..0000000 --- a/feynsum-sml/src/common/GateSchedulerNaive.sml +++ /dev/null @@ -1,34 +0,0 @@ -(* A "naive" scheduler: execute in straightline order, i.e., as the gates are - * written in the input qasm. No fusion. 
- *) -structure GateSchedulerNaive: -sig - val scheduler: GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - datatype sched = S of {numGates: int, next: int ref} - - type t = sched - - fun new - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - } = - S {numGates = numGates, next = ref 0} - - fun pickNext (S {numGates, next, ...}) = - if !next >= numGates then Seq.empty () - else let val gi = !next in next := gi + 1; Seq.singleton gi end - - fun scheduler args = - let val sched = new args - in fn () => pickNext sched - end - -end diff --git a/feynsum-sml/src/common/HashSet.sml b/feynsum-sml/src/common/HashSet.sml new file mode 100644 index 0000000..05c0546 --- /dev/null +++ b/feynsum-sml/src/common/HashSet.sml @@ -0,0 +1,138 @@ +structure HashSet :> +sig + type 'a t + type 'a table = 'a t + + exception Full + + val make: {hash: 'a -> int, eq: 'a * 'a -> bool, capacity: int, maxload: real} -> 'a table + val size: 'a table -> int + val capacity: 'a table -> int + val resize: 'a table -> 'a table + val increaseCapacityTo: int -> 'a table -> 'a table + val insert: 'a table -> 'a -> unit + val lookup: 'a table -> 'a -> bool + val compact: 'a table -> 'a Seq.t + + (* Unsafe because underlying array is shared. If the table is mutated, + * then the Seq would not appear to be immutable. + * + * Could also imagine a function `freezeViewContents` which marks the + * table as immutable (preventing further inserts). That would be a safer + * version of this function. 
+ *) + val unsafeViewContents: 'a table -> 'a option Seq.t +end = +struct + + datatype 'a t = + T of + { data: 'a option array + , hash: 'a -> int + , eq: 'a * 'a -> bool + , maxload: real + } + + exception Full + + type 'a table = 'a t + + + fun make {hash, eq, capacity, maxload} = + if capacity = 0 then + raise Fail "HashTable.make: capacity 0" + else + let val data = SeqBasis.tabulate 5000 (0, capacity) (fn _ => NONE) + in T {data = data, hash = hash, eq = eq, maxload = maxload} + end + + + fun unsafeViewContents (T {data, ...}) = ArraySlice.full data + + + fun bcas (arr, i) (old, new) = + MLton.eq (old, Concurrency.casArray (arr, i) (old, new)) + + + fun size (T {data, ...}) = + SeqBasis.reduce 10000 op+ 0 (0, Array.length data) (fn i => + if Option.isSome (Array.sub (data, i)) then 1 else 0) + + + fun capacity (T {data, ...}) = Array.length data + + + fun insert' (T {data, hash, eq, maxload}) x force = + let + val n = Array.length data + val tolerance = 20 * Real.ceil (1.0 / (1.0 - maxload)) + + fun loop i probes = + if not force andalso probes >= tolerance then + raise Full + else if i >= n then + loop 0 probes + else + let val current = Array.sub (data, i) in + case current of + SOME y => + if eq (x, y) then () else loop (i + 1) (probes + 1) + | NONE => + if bcas (data, i) (NONE, SOME x) then + () + else + loop i probes + end + + val start = (hash x) mod (Array.length data) + in + loop start 0 + end + + + fun insert s x = insert' s x false + + + fun lookup (T {data, hash, eq, ...}) x = + let + val n = Array.length data + val start = (hash x) mod n + + fun loop i = + case Array.sub (data, i) of + NONE => false + | SOME y => eq (x, y) orelse loopCheck (i + 1) + + and loopCheck i = + if i >= n then loopCheck 0 else (i <> start andalso loop i) + in + n <> 0 andalso loop start + end + + + fun increaseCapacityTo newcap (input as T {data, hash, eq, maxload}) = + if newcap < capacity input then + raise Fail "HashTable.increaseCapacityTo: new cap is too small" + else 
+ let + val new = make + {hash = hash, eq = eq, capacity = newcap, maxload = maxload} + in + ForkJoin.parfor 1000 (0, Array.length data) (fn i => + case Array.sub (data, i) of + NONE => () + | SOME x => (insert' new x true; ())); + + new + end + + + fun resize x = + increaseCapacityTo (2 * capacity x) x + + + fun compact (T {data, ...}) = + ArraySlice.full (SeqBasis.tabFilter 2000 (0, Array.length data) (fn i => + Array.sub (data, i))) + +end diff --git a/feynsum-sml/src/common/BASIS_IDX.sml b/feynsum-sml/src/common/basis/BASIS_IDX.sml similarity index 100% rename from feynsum-sml/src/common/BASIS_IDX.sml rename to feynsum-sml/src/common/basis/BASIS_IDX.sml diff --git a/feynsum-sml/src/common/BasisIdx64.sml b/feynsum-sml/src/common/basis/BasisIdx64.sml similarity index 100% rename from feynsum-sml/src/common/BasisIdx64.sml rename to feynsum-sml/src/common/basis/BasisIdx64.sml diff --git a/feynsum-sml/src/common/BasisIdxUnlimited.sml b/feynsum-sml/src/common/basis/BasisIdxUnlimited.sml similarity index 100% rename from feynsum-sml/src/common/BasisIdxUnlimited.sml rename to feynsum-sml/src/common/basis/BasisIdxUnlimited.sml diff --git a/feynsum-sml/src/common/Complex.sml b/feynsum-sml/src/common/complex/Complex.sml similarity index 100% rename from feynsum-sml/src/common/Complex.sml rename to feynsum-sml/src/common/complex/Complex.sml diff --git a/feynsum-sml/src/common/Complex32.sml b/feynsum-sml/src/common/complex/Complex32.sml similarity index 100% rename from feynsum-sml/src/common/Complex32.sml rename to feynsum-sml/src/common/complex/Complex32.sml diff --git a/feynsum-sml/src/common/Complex64.sml b/feynsum-sml/src/common/complex/Complex64.sml similarity index 100% rename from feynsum-sml/src/common/Complex64.sml rename to feynsum-sml/src/common/complex/Complex64.sml diff --git a/feynsum-sml/src/common/MkComplex.sml b/feynsum-sml/src/common/complex/MkComplex.sml similarity index 100% rename from feynsum-sml/src/common/MkComplex.sml rename to 
feynsum-sml/src/common/complex/MkComplex.sml diff --git a/feynsum-sml/src/common/depgraph.py b/feynsum-sml/src/common/depgraph.py new file mode 100644 index 0000000..d06e669 --- /dev/null +++ b/feynsum-sml/src/common/depgraph.py @@ -0,0 +1,332 @@ +from numpy.typing import NDArray +import qiskit as Q +from qiskit.transpiler.passes import RemoveBarriers +import rustworkx as rx +import numpy as np +import sys +import json +from typing import Any, Iterator, Optional +import time + +#import scipy +#def cluster(dag): +# adjmtx = rx.adjacency_matrix(dag) +# adjid = np.identity(adjmtx.shape[0]) +# whitened = scipy.cluster.vq.whiten(adjmtx + adjid) +# return scipy.cluster.vq.kmeans2(whitened, adjmtx.shape[0]//2, minit='points')[1] + +EPS = 1e-10 + +def circuit_to_dag(circ: Q.circuit.QuantumCircuit): + return Q.converters.circuit_to_dag(circ) + +def circuit_to_dep_graph(circ: Q.circuit.QuantumCircuit) -> Q.dagcircuit.DAGCircuit: + dag = circuit_to_dag(circ) + dag._multi_graph = dag_to_dep_graph(dag._multi_graph, circuit_find_qubit_dict(circ)) + return dag + +def circuit_find_qubit_dict(circ: Q.circuit.QuantumCircuit) -> dict[Q.circuit.Bit, int]: + return {x: circ.find_bit(x)[0] for x in circ.qubits} + +def indirect_reachability(dag): + "Returns a mapping from nodes to sets of nodes reachable from them, via 2 or more edges" + reaches = dict() + def visit(node): + nodes = {node} + reaches[node] = nodes + outs = dag.out_edges(node) + for _, to, _ in outs: + if to not in reaches: + visit(to) + nodes |= reaches[to] | ({i for _, i, _ in dag.out_edges(to)} - set(outs)) + for node in dag.node_indices(): + visit(node) + return reaches + +def trim_edges(dag): + todo = set(dag.edge_list()) + indirect_reaches = indirect_reachability(dag) + + def get_id(node): + if '_node_id' in dir(node): + return node._node_id + elif 'node_id' in dir(node): + return node.node_id + + while todo: + edge = todo.pop() + gai, gbi = edge + if any(gbi in indirect_reaches[get_id(gci)] and gbi != get_id(gci) 
for gci in dag.successors(gai)): + dag.remove_edge(gai, gbi) + +def is_op_node(node: Q.dagcircuit.DAGNode) -> bool: + return isinstance(node, Q.dagcircuit.DAGOpNode) + +def dag_to_dep_graph(dag: Q.dagcircuit.DAGCircuit, + find_qubit: dict[Q.circuit.Bit, int], + trim=False) -> rx.PyDAG: + assert isinstance(dag, rx.PyDAG), \ + f"dag_to_dep_graph expected a DAG of type rustworkx.PyDAG, but got {type(dag)} instead" + graph = rx.PyDAG(multigraph=False) + graph.add_nodes_from(dag.nodes()) + reaches = {i: {i} for i in dag.node_indices()} + todo = set() + for gai, gbi, qi in dag.weighted_edge_list(): + graph.add_edge(gai, gbi, None) + if is_op_node(graph.get_node_data(gai)) and is_op_node(graph.get_node_data(gbi)): + todo.add((gai, gbi)) + visited = set() + + def maybe_push_edge(ifm, ito, gfm, gto): + """ + Pushes an edge to `todo` if we need to visit it + Args: + ifm = index of source gate + ito = index of target gate + gfm = source gate + gto = target gate + """ + + opfm = is_op_node(gfm) + opto = is_op_node(gto) + + if opfm and opto: + shared = any(find_qubit[q1] == find_qubit[q2] + for q1 in gfm.qargs for q2 in gto.qargs) + if shared and (ifm, ito) not in todo \ + and (ifm, ito) not in visited \ + and (not trim or ito not in reaches[ifm]): + graph.add_edge(ifm, ito, None) + todo.add((ifm, ito)) + elif (opfm and any(find_qubit[q] == find_qubit[gto.wire] for q in gfm.qargs)) or \ + (opto and any(find_qubit[q] == find_qubit[gfm.wire] for q in gto.qargs)) or \ + (not opfm and not opto and gfm.wire == gto.wire): + graph.add_edge(ifm, ito, None) + + while todo: + gai, gbi = todo.pop() + visited.add((gai, gbi)) + ga = graph.get_node_data(gai) + gb = graph.get_node_data(gbi) + if commutes(ga, gb, find_qubit): + graph.remove_edge(gai, gbi) + # Propagate in-edges + for pidx, _, _ in graph.in_edges(gai): + maybe_push_edge(pidx, gbi, graph.get_node_data(pidx), gb) + # Propagate out-edges + for _, cidx, _ in graph.out_edges(gbi): + maybe_push_edge(gai, cidx, ga, 
graph.get_node_data(cidx)) + elif trim: + reaches[gai] |= reaches[gbi] + return graph + +def as_bits(n: int, nbits: Optional[int] = None) -> list[int]: + "Converts an integer into its bit representation" + return [int(bool(n & (1 << (i - 1)))) for i in range(nbits or n.bit_length(), 0, -1)] + +def rearrange_gate(mat: NDArray, old: list[int], new: list[int]) -> NDArray: + """ + Rearranges a gate's unitary matrix for application to a new set of qubits. + Assumes set(old) = set(new). + """ + old = list(reversed(old)) + new = list(reversed(new)) + qubits = len(new) + old_idx = {o:i for i, o in enumerate(old)} + new_idx = {n:i for i, n in enumerate(new)} + old2new_idx = {o:new_idx[o] for o in old} + new2old_idx = {n:old_idx[n] for n in new} + + size = 1 << qubits + I = np.identity(size, dtype=np.dtype('int64')) + + bit_map = np.array([I[new2old_idx[n]] for n in new]) + bit_mat = np.array([as_bits(i, qubits) for i in range(size)]) + mapped = bit_mat @ bit_map + for i in range(qubits - 1): + mapped[:, i] <<= qubits - i - 1 + reordered = mapped.sum(axis=1) + + mat1 = np.ndarray(mat.shape, dtype=mat.dtype) + mat2 = np.ndarray(mat.shape, dtype=mat.dtype) + + for i in range(size): + mat1[i, :] = mat[reordered[i], :] + for i in range(size): + mat2[:, i] = mat1[:, reordered[i]] + return mat2 + +def align_gates(ga, gb, find_qubit: dict[Q.circuit.Bit, int]): + """ + Aligns gates along the same qubits, returning a tuple of their new unitaries + """ + ma = ga.op.to_matrix() + mb = gb.op.to_matrix() + qas = [find_qubit[qi] for qi in ga.qargs] + qbs = [find_qubit[qi] for qi in gb.qargs] + qas_ins = list(set(qbs) - set(qas)) + qbs_ins = list(set(qas) - set(qbs)) + qas2 = qas + qas_ins + qbs2 = qbs + qbs_ins + # Add additional qubits from gb + ma2 = np.kron(np.identity(1 << len(qas_ins), dtype=ma.dtype), ma) + # Add additional qubits from ga + mb2 = np.kron(np.identity(1 << len(qbs_ins), dtype=mb.dtype), mb) + # Rearrange mb2 to match ma2's qubit order + mb3 = rearrange_gate(mb2, qbs2, 
qas2) + return (ma2, mb3) + +def hardcoded_commutes(ga: Q.circuit.Gate, gb: Q.circuit.Gate, find_qubit: dict[Q.circuit.Bit, int]) -> int: + "Returns 0 if N/A, 1 if commute, 2 if dependent" + NA, COMMUTE, DEPENDENT = 0, 1, 2 + gan, gbn = ga.op.name, gb.op.name + qas, qbs = [find_qubit[q] for q in ga.qargs], [find_qubit[q] for q in gb.qargs] + + def comm(gan, gbn, qas, qbs): + if gan == 'cx' and gbn in ['x', 'sx', 'rx']: + return qas[1] == qbs[0] + elif gan == 'cx' and gbn in ['z', 'rz']: + return qas[0] == qbs[0] + elif gan == 'cx' and gbn == 'cx': + return qas[0] != qbs[1] and qas[1] != qbs[0] + elif gan == 'ccx' and gbn == 'cx': + return qas[2] != qbs[0] and qbs[1] not in qas[:2] + elif gan == 'ccx' and gbn == 'ccx': + return qas[2] not in qbs[:2] and qbs[2] not in qas[:2] + elif gan == 'ccx' and gbn in ['x', 'sx', 'rx']: + return qbs[0] not in qas[:2] + elif gan == 'ccx' and gbn in ['z', 'rz']: + return qbs[0] in qas[:2] + + if gan == gbn and qas == qbs: + return COMMUTE + else: + c = comm(gan, gbn, qas, qbs) + return NA if c is None else 2 - c + +def commutes(ga: Q.circuit.Gate, gb: Q.circuit.Gate, + find_qubit: dict[Q.circuit.Bit, int]) -> bool: + hc = hardcoded_commutes(ga, gb, find_qubit) or hardcoded_commutes(gb, ga, find_qubit) + if hc == 0: # we don't have this case in the hardcoded rules + ma, mb = align_gates(ga, gb, find_qubit) + return (np.abs((ma @ mb) - (mb @ ma)) < EPS).all() + return bool(hc % 2) # we've have this case in the hardcoded rules + +def read_qasm(fh) -> Q.circuit.QuantumCircuit: + #return Q.QuantumCircuit.from_qasm_file(fh) + acc = [] + for line in fh: + if not line.startswith('//'): + acc.append(line) + return Q.QuantumCircuit.from_qasm_str(''.join(acc)) + +def write_dep_graph(num_qubits: int, graph, find_qubit: dict[Q.circuit.Bit, int], fh): + nodemap = dict() + numnodes = 0 + nodes = [] + edges = [] + for node in graph.node_indices(): + data = graph.get_node_data(node) + if is_op_node(data): + nodemap[node] = numnodes + numnodes += 1 
+ nodes.append(data) + for fm, to in graph.edge_list(): + if is_op_node(graph.get_node_data(fm)) and is_op_node(graph.get_node_data(to)): + edges.append((nodemap[fm], nodemap[to])) + + # TODO: handle cargs + node_data = [] + for node in nodes: + if node.op.params: + node_data.append({ + 'name': node.op.name, + 'params': node.op.params, + 'qargs': [find_qubit[qarg] for qarg in node.qargs], + #'cargs': [f'{}' for carg in node.cargs] + }) + else: + node_data.append({ + 'name': node.op.name, + 'qargs': [find_qubit[qarg] for qarg in node.qargs], + #'cargs': [f'{}' for carg in node.cargs] + }) + data = {'qubits': num_qubits, 'nodes': node_data, 'edges': edges} + json.dump(data, fh) + + +def usage(argv): + return f""" +Usage: + {argv[0]} [--dag] [input.qasm] [output.json] +If either arg is omitted, read from stdin/stdout +If --dag, immediately output unprocessed DAG +""" + +def main(argv): + just_dag = False + if len(argv) == 2: + ifh = open(argv[1]) + ofh = sys.stdout + elif len(argv) == 3: + ifh = open(argv[1]) + ofh = open(argv[2], 'w') + elif len(argv) == 4 and argv[1] == '--dag': + just_dag = True + ifh = open(argv[2]) + ofh = open(argv[3], 'w') + else: + print(usage(argv).strip(), file=sys.stderr) + return 1 + circuit = read_qasm(ifh) + circuit.remove_final_measurements(True) + rb = RemoveBarriers() + # TODO: prevent commuting across barriers + circuit = rb(circuit) + dag = circuit_to_dag(circuit) + find_qubit = circuit_find_qubit_dict(circuit) + if just_dag: + write_dep_graph(circuit.num_qubits, dag._multi_graph, find_qubit, ofh) + else: + dg = dag_to_dep_graph(dag._multi_graph, find_qubit, trim=True) + trim_edges(dg) + write_dep_graph(circuit.num_qubits, dg, find_qubit, ofh) + +# def glen(generator: Iterator[Any]) -> int: +# return sum(1 for _ in generator) + +# def op_edges(g: Q.dagcircuit.DAGCircuit) -> int: +# return glen(filter(lambda e: isinstance(e[0], Q.dagcircuit.DAGOpNode) and isinstance(e[1], Q.dagcircuit.DAGOpNode), g.edges())) + +# def main(argv): +# 
if len(argv) == 2: +# circuit = read_qasm(argv[1]) +# my_dep = circuit_to_dag(circuit) +# qk_dag = circuit_to_dag(circuit) +# orig_depth = my_dep.depth() +# orig_edges = glen(my_dep.edges()) +# print(f"Original DAG: {orig_depth - 1} depth, {orig_edges} edges") + +# TRIM = False +# my_start = time.time() +# my_dep._multi_graph = dag_to_dep_graph(my_dep._multi_graph, circuit_find_qubit_dict(circuit), TRIM) +# my_end = time.time() + +# if TRIM: +# trim_start = time.time() +# trim_edges(my_dep._multi_graph) +# trim_end = time.time() +# print(f"My dependency graph: {my_dep.depth() - 1} depth, {op_edges(my_dep)} edges, {my_end - my_start:0.4f} sec + trimming for {trim_end - trim_start:0.4f} sec") +# else: +# print(f"My dependency graph: {my_dep.depth() - 1} depth, {op_edges(my_dep)} edges, {my_end - my_start:0.4f} sec") +# qk_start = time.time() +# qk_dep = Q.converters.dag_to_dagdependency(qk_dag) +# qk_end = time.time() + +# print(f"Qiskit dependency graph: {qk_dep.depth()} depth, {glen(qk_dep.get_all_edges())} edges, {qk_end - qk_start:0.4f} sec") +# else: +# print("Pass a .qasm file as arg", file=sys.stderr) + +if __name__ == '__main__': + exitcode = main(sys.argv) + exit(exitcode) diff --git a/feynsum-sml/src/common/schedulers/FinishQubit.sml b/feynsum-sml/src/common/schedulers/FinishQubit.sml new file mode 100644 index 0000000..5b27449 --- /dev/null +++ b/feynsum-sml/src/common/schedulers/FinishQubit.sml @@ -0,0 +1,226 @@ +functor FinishQubitScheduler + (val maxBranchingStride: int val disableFusion: bool): +sig + val scheduler: GateScheduler.t + val scheduler2: GateScheduler.t + val scheduler3: GateScheduler.t + val scheduler4: GateScheduler.t + val scheduler5: GateScheduler.t + val schedulerRandom: int -> GateScheduler.t +end = +struct + + type gate_idx = int + + fun DFS ((new, old) : (int list * int list)) = new @ old + fun BFS ((new, old) : (int list * int list)) = old @ new + + fun revTopologicalSort (dfg: DataFlowGraph.t) (push: (int list * int list) -> int 
list) = + let val N = Seq.length (#gates dfg) + val ind = Array.tabulate (N, Seq.length o Seq.nth (#preds dfg)) + fun decInd i = let val d = Array.sub (ind, i) in Array.update (ind, i, d - 1); d - 1 end + val queue = ref nil + (*val push = case tr of BFS => (fn xs => queue := (!queue) @ xs) | DFS => (fn xs => queue := xs @ (!queue))*) + fun pop () = case !queue of nil => NONE | x :: xs => (queue := xs; SOME x) + val _ = queue := push (List.filter (fn i => Seq.length (Seq.nth (#preds dfg) i) = 0) (List.tabulate (N, fn i => i)), !queue) + fun loop L = + case pop () of + NONE => L + | SOME n => + (let val ndeps = Seq.nth (#succs dfg) n in + queue := push (List.filter (fn m => decInd m = 0) (List.tabulate (Seq.length ndeps, Seq.nth ndeps)), !queue) + end; + loop (n :: L)) + in + loop nil + end + + fun topologicalSort (dfg: DataFlowGraph.t) (push: (int list * int list) -> int list) = + List.rev (revTopologicalSort dfg push) + + val gateDepths: int array option ref = ref NONE + + fun computeGateDepths (dfg: DataFlowGraph.t) = + let val N = Seq.length (#gates dfg) + val depths = Array.array (N, ~1) + fun gdep i = + Array.update (depths, i, 1 + Seq.reduce Int.max ~1 (Seq.map (fn j => Array.sub (depths, j)) (Seq.nth (#succs dfg) i))) + (*case Array.sub (depths, i) of + ~1 => 1 + Seq.reduce Int.min ~1 (Seq.map gateDepth (Seq.nth (#deps dfg) i)) + | d => d*) + in + List.foldl (fn (i, ()) => gdep i) () (revTopologicalSort dfg DFS); depths + end + + fun sortList (lt: 'a * 'a -> bool) (xs: 'a list) = + let fun insert (x, xs) = + case xs of + nil => x :: nil + | x' :: xs => if lt (x, x') then x :: x' :: xs else x' :: insert (x, xs) + in + List.foldr insert nil xs + end + + (* Choose in reverse topological order, sorted by easiest qubit to finish *) + fun scheduler3 dfg = + let val dfgt = DataFlowGraphUtil.transpose dfg + val depths = computeGateDepths dfg + val gib = DataFlowGraphUtil.gateIsBranching dfg + fun lt (a, b) = Array.sub (depths, a) < Array.sub (depths, b) orelse 
(Array.sub (depths, a) = Array.sub (depths, b) andalso not (gib a) andalso gib b) + fun push (new, old) = DFS (sortList lt new, old) + val xs = revTopologicalSort dfgt push + val N = Seq.length (#gates dfg) + val ord = Array.array (N, ~1) + fun writeOrd i xs = case xs of nil => () | x :: xs' => (Array.update (ord, x, i); writeOrd (i + 1) xs') + val _ = writeOrd 0 xs + fun pickEarliestOrd best_idx best_ord i gates = + if i = Seq.length gates then + best_idx + else let val cur_idx = Seq.nth gates i + val cur_ord = Array.sub (ord, cur_idx) + in + if cur_ord < best_ord then pickEarliestOrd cur_idx cur_ord (i + 1) gates else pickEarliestOrd best_idx best_ord (i + 1) gates + end + in + fn gates => let val g0 = Seq.nth gates 0 in + pickEarliestOrd g0 (Array.sub (ord, g0)) 1 gates + end + end + + fun gateDepth i dfg = + case !gateDepths of + NONE => let val gd = computeGateDepths dfg in + print "recompouting gate depths"; + gateDepths := SOME gd; + Array.sub (gd, i) + end + | SOME gd => Array.sub (gd, i) + + fun pickLeastDepth best_idx best_depth i gates dfg = + if i = Seq.length gates then + best_idx + else + let val cur_idx = Seq.nth gates i + val cur_depth = gateDepth cur_idx dfg in + if cur_depth < best_depth then + pickLeastDepth cur_idx cur_depth (i + 1) gates dfg + else + pickLeastDepth best_idx best_depth (i + 1) gates dfg + end + + fun pickGreatestDepth best_idx best_depth i gates dfg = + if i = Seq.length gates then + best_idx + else + let val cur_idx = Seq.nth gates i + val cur_depth = gateDepth cur_idx dfg in + if cur_depth > best_depth then + pickGreatestDepth cur_idx cur_depth (i + 1) gates dfg + else + pickGreatestDepth best_idx best_depth (i + 1) gates dfg + end + + (* From a frontier, select which gate to apply next *) + fun scheduler dfg = + (gateDepths := SOME (computeGateDepths dfg); + fn gates => let val g0 = Seq.nth gates 0 in + pickLeastDepth g0 (gateDepth g0 dfg) 1 gates dfg + end) + + (* Select gate with greatest number of descendants *) + fun 
scheduler4 dfg = + (gateDepths := SOME (computeGateDepths dfg); + fn gates => let val g0 = Seq.nth gates 0 in + pickGreatestDepth g0 (gateDepth g0 dfg) 1 gates dfg + end) + + structure G = Gate (structure B = BasisIdxUnlimited + structure C = Complex64) + + (* Hybrid of scheduler2 (avoid branching on unbranched qubits) and also scheduler3 (choose in reverse topological order, sorted by easiest qubit to finish) *) + fun scheduler5 (dfg: DataFlowGraph.t) = + let val gates = Seq.map G.fromGateDefn (#gates dfg) + fun touches i = #touches (Seq.nth gates i) + fun branches i = case #action (Seq.nth gates i) of G.NonBranching _ => 0 | G.MaybeBranching _ => 1 | G.Branching _ => 2 + + val dfgt = DataFlowGraphUtil.transpose dfg + val depths = computeGateDepths dfg + fun lt (a, b) = Array.sub (depths, a) < Array.sub (depths, b) orelse (Array.sub (depths, a) = Array.sub (depths, b) andalso branches a < branches b) + fun push (new, old) = DFS (sortList lt new, old) + val xs = revTopologicalSort dfgt push + val N = Seq.length (#gates dfg) + val ord = Array.array (N, ~1) + fun writeOrd i xs = case xs of nil => () | x :: xs' => (Array.update (ord, x, i); writeOrd (i + 1) xs') + val _ = writeOrd 0 xs + val touched = Array.array (#numQubits dfg, false) + fun touch i = Array.update (touched, i, true) + fun touchAll gidx = let val ts = touches gidx in List.tabulate (Seq.length ts, fn i => touch (Seq.nth ts i)); () end + fun newTouches i = + Seq.length (Seq.filter (fn j => not (Array.sub (touched, j))) (touches i)) + fun pickLeastNewTouches best_idx best_newTouches best_ord i gates = + if i = Seq.length gates then + ((* print ("Picked " ^ Int.toString best_idx ^ ", new touches " ^ Int.toString best_newTouches ^ "\n"); *) + best_idx) + else + let val cur_idx = Seq.nth gates i + val cur_newTouches = newTouches cur_idx + val cur_ord = Array.sub (ord, cur_idx) + in + if cur_newTouches < best_newTouches + orelse (cur_newTouches = best_newTouches + andalso cur_ord < best_ord) then + 
pickLeastNewTouches cur_idx cur_newTouches cur_ord (i + 1) gates + else + pickLeastNewTouches best_idx best_newTouches best_ord (i + 1) gates + end + in + fn gates => let val g0 = Seq.nth gates 0 + val next = pickLeastNewTouches g0 (newTouches g0) (Array.sub (ord, g0)) 1 gates + in + touchAll next; next + end + end + + (* Avoids branching on unbranched qubits *) + fun scheduler2 (dfg: DataFlowGraph.t) = + let val touched = Array.array (#numQubits dfg, false) + val gates = Seq.map G.fromGateDefn (#gates dfg) + fun touches i = #touches (Seq.nth gates i) + fun branches i = case #action (Seq.nth gates i) of G.NonBranching _ => 0 | G.MaybeBranching _ => 1 | G.Branching _ => 2 + fun touch i = Array.update (touched, i, true) + fun touchAll gidx = let val ts = touches gidx in List.tabulate (Seq.length ts, fn i => touch (Seq.nth ts i)); () end + fun newTouches i = + Seq.length (Seq.filter (fn j => not (Array.sub (touched, j))) (touches i)) + fun pickLeastNewTouches best_idx best_newTouches i gates = + if i = Seq.length gates then + ((* print ("Picked " ^ Int.toString best_idx ^ ", new touches " ^ Int.toString best_newTouches ^ "\n"); *) + best_idx) + else + let val cur_idx = Seq.nth gates i + val cur_newTouches = newTouches cur_idx + in + if cur_newTouches < best_newTouches + orelse (cur_newTouches = best_newTouches + andalso branches cur_idx < branches best_idx) then + pickLeastNewTouches cur_idx cur_newTouches (i + 1) gates + else + pickLeastNewTouches best_idx best_newTouches (i + 1) gates + end + in + fn gates => let val g0 = Seq.nth gates 0 + val next = pickLeastNewTouches g0 (newTouches g0) 1 gates + in + touchAll next; next + end + end + + val seed = Random.rand (50, 14125) + + fun schedulerRandom seedNum dfg = + (*let val seed = Random.rand (seedNum, seedNum * seedNum) in*) + fn gates => let val r = Random.randRange (0, Seq.length gates - 1) seed in + ((*print ("Randomly chose " ^ Int.toString r ^ " from range [0, " ^ Int.toString (Seq.length gates) ^ ")\n");*) + 
Seq.nth gates r) + end + (* end *) +end diff --git a/feynsum-sml/src/common/schedulers/GreedyBranching.sml b/feynsum-sml/src/common/schedulers/GreedyBranching.sml new file mode 100644 index 0000000..3d544a6 --- /dev/null +++ b/feynsum-sml/src/common/schedulers/GreedyBranching.sml @@ -0,0 +1,23 @@ +structure GreedyBranchingScheduler: +sig + val scheduler: GateScheduler.t +end = +struct + + type gate_idx = int + + fun pickBranching i branching gates = + if i < Seq.length gates then + (if branching i then + Seq.nth gates i + else + pickBranching (i + 1) branching gates) + else + Seq.nth gates 0 (* pick a non-branching gate *) + + (* From a frontier, select which gate to apply next *) + fun scheduler dg = + let val branching = DataFlowGraphUtil.gateIsBranching dg in + fn gates => pickBranching 0 branching gates + end +end diff --git a/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml b/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml new file mode 100644 index 0000000..1206fbd --- /dev/null +++ b/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml @@ -0,0 +1,29 @@ +functor GreedyNonBranchingScheduler + (val maxBranchingStride: int val disableFusion: bool): +sig + val scheduler: GateScheduler.t +end = +struct + + type gate_idx = int + + type args = + { depGraph: DataFlowGraph.t + , gateIsBranching: gate_idx -> bool + } + + fun pickNonBranching i branching ftr = + if i < Seq.length ftr then + (if branching i then + pickNonBranching (i + 1) branching ftr + else + Seq.nth ftr i) + else + Seq.nth ftr 0 + + (* From a frontier, select which gate to apply next *) + fun scheduler dg = + let val branching = DataFlowGraphUtil.gateIsBranching dg in + fn gates => pickNonBranching 0 branching gates + end +end diff --git a/feynsum-sml/src/common/sources.mlb b/feynsum-sml/src/common/sources.mlb index fc6630e..99d75a8 100644 --- a/feynsum-sml/src/common/sources.mlb +++ b/feynsum-sml/src/common/sources.mlb @@ -20,23 +20,26 @@ local structure SMLQasmParser = Parser 
end + $(SML_LIB)/smlnj-lib/JSON/json-lib.mlb + $(SML_LIB)/smlnj-lib/Util/smlnj-lib.mlb + HashTable.sml - ApplyUntilFailure.sml + HashSet.sml Rat.sml QubitIdx.sml Constants.sml - MkComplex.sml - Complex32.sml - Complex64.sml - Complex.sml + complex/MkComplex.sml + complex/Complex32.sml + complex/Complex64.sml + complex/Complex.sml - BASIS_IDX.sml + basis/BASIS_IDX.sml ann "allowExtendedTextConsts true" in - BasisIdx64.sml - BasisIdxUnlimited.sml + basis/BasisIdx64.sml + basis/BasisIdxUnlimited.sml end ann "allowExtendedTextConsts true" in @@ -45,25 +48,30 @@ local Circuit.sml end - SPARSE_STATE_TABLE.sml - SparseStateTable.sml - SparseStateTableLockedSlots.sml + state/SPARSE_STATE_TABLE.sml + state/SparseStateTable.sml + state/SparseStateTableLockedSlots.sml - DENSE_STATE.sml - DenseState.sml - - ExpandState.sml + state/DENSE_STATE.sml + state/DenseState.sml - GateScheduler.sml - GateSchedulerNaive.sml - GateSchedulerGreedyNonBranching.sml - GateSchedulerGreedyBranching.sml - GateSchedulerGreedyFinishQubit.sml + state/HYBRID_STATE.sml + state/HybridState.sml + + state/ExpandState.sml Fingerprint.sml + + DataFlowGraph.sml + DataFlowGraphUtil.sml + GateScheduler.sml + schedulers/GreedyBranching.sml + schedulers/GreedyNonBranching.sml + schedulers/FinishQubit.sml + DynGateScheduler.sml in structure HashTable - structure ApplyUntilFailure + structure HashSet structure Rat @@ -103,17 +111,26 @@ in signature DENSE_STATE functor DenseState + + signature HYBRID_STATE + functor HybridState functor ExpandState - structure GateScheduler - structure GateSchedulerNaive - functor GateSchedulerGreedyNonBranching - structure GateSchedulerGreedyBranching - functor GateSchedulerGreedyFinishQubit - functor RedBlackMapFn functor RedBlackSetFn functor Fingerprint + + structure DataFlowGraph + structure DataFlowGraphUtil + structure GateScheduler + structure GreedyBranchingScheduler + functor GreedyNonBranchingScheduler + functor FinishQubitScheduler + (*structure GateSchedulerOrder*) + 
signature DYN_GATE_SCHEDULER + functor DynSchedFinishQubitWrapper + functor DynSchedInterference + functor DynSchedNaive end \ No newline at end of file diff --git a/feynsum-sml/src/common/DENSE_STATE.sml b/feynsum-sml/src/common/state/DENSE_STATE.sml similarity index 100% rename from feynsum-sml/src/common/DENSE_STATE.sml rename to feynsum-sml/src/common/state/DENSE_STATE.sml diff --git a/feynsum-sml/src/common/DenseState.sml b/feynsum-sml/src/common/state/DenseState.sml similarity index 100% rename from feynsum-sml/src/common/DenseState.sml rename to feynsum-sml/src/common/state/DenseState.sml diff --git a/feynsum-sml/src/common/ExpandState.sml b/feynsum-sml/src/common/state/ExpandState.sml similarity index 90% rename from feynsum-sml/src/common/ExpandState.sml rename to feynsum-sml/src/common/state/ExpandState.sml index 93788ab..92e8374 100644 --- a/feynsum-sml/src/common/ExpandState.sml +++ b/feynsum-sml/src/common/state/ExpandState.sml @@ -1,34 +1,31 @@ functor ExpandState (structure B: BASIS_IDX structure C: COMPLEX - structure SST: SPARSE_STATE_TABLE - structure DS: DENSE_STATE + structure HS: HYBRID_STATE structure G: GATE - sharing B = SST.B = DS.B = G.B - sharing C = SST.C = DS.C = G.C + sharing B = HS.B = G.B + sharing C = HS.C = G.C val blockSize: int val maxload: real val denseThreshold: real val pullThreshold: real) :> sig - datatype state = - Sparse of SST.t - | Dense of DS.t - | DenseKnownNonZeroSize of DS.t * int - val expand: { gates: G.t Seq.t , numQubits: int , maxNumStates: IntInf.int - , state: state + , state: HS.state , prevNonZeroSize: int } - -> {result: state, method: string, numNonZeros: int, numGateApps: int} + -> {result: HS.state, method: string, numNonZeros: int, numGateApps: int} end = struct + structure SST = HS.SST + structure DS = HS.DS + (* 0 < r < 1 * * I wish this wasn't so difficult @@ -55,7 +52,7 @@ struct val d = Char.ord (String.sub (digits, depth)) - Char.ord #"0" val _ = if 0 <= d andalso d <= 9 then () - else raise Fail 
"riMult: bad digit" + else raise Fail ("riMult: bad digit " ^ digits ^ ", " ^ Real.toString r ^ ", " ^ IntInf.toString i) val acc = acc + (i * IntInf.fromInt d) div (IntInf.pow (10, depth + 1)) in @@ -108,13 +105,6 @@ struct AllSucceeded | SomeFailed of {widx: B.t * C.t, gatenum: int} list - - datatype state = - Sparse of SST.t - | Dense of DS.t - | DenseKnownNonZeroSize of DS.t * int - - fun expandSparse {gates: G.t Seq.t, numQubits, state, expected} = let val numGates = Seq.length gates @@ -122,9 +112,9 @@ struct val stateSeq = case state of - Sparse sst => DelayedSeq.map SOME (SST.compact sst) - | Dense state => DS.unsafeViewContents state - | DenseKnownNonZeroSize (state, _) => DS.unsafeViewContents state + HS.Sparse sst => DelayedSeq.map SOME (SST.compact sst) + | HS.Dense state => DS.unsafeViewContents state + | HS.DenseKnownNonZeroSize (state, _) => DS.unsafeViewContents state (* number of initial elements *) val n = DelayedSeq.length stateSeq @@ -205,7 +195,7 @@ struct case apply widx of G.OutputOne widx' => doGates (apps + 1) (widx', gatenum + 1) | G.OutputTwo (widx1, widx) => - doTwo (apps + 1) ((widx1, widx), gatenum + 1) + doTwo (apps + 2) ((widx1, widx), gatenum + 1) and doTwo apps ((widx1, widx2), gatenum) = case doGates apps (widx1, gatenum) of @@ -285,7 +275,7 @@ struct val (apps, output) = loop 0 initialBlocks initialTable in - {result = Sparse output, numGateApps = apps} + {result = HS.Sparse output, numGateApps = apps} end @@ -296,9 +286,9 @@ struct val stateSeq = case state of - Sparse sst => DelayedSeq.map SOME (SST.compact sst) - | Dense state => DS.unsafeViewContents state - | DenseKnownNonZeroSize (state, _) => DS.unsafeViewContents state + HS.Sparse sst => DelayedSeq.map SOME (SST.compact sst) + | HS.Dense state => DS.unsafeViewContents state + | HS.DenseKnownNonZeroSize (state, _) => DS.unsafeViewContents state (* number of initial elements *) val n = DelayedSeq.length stateSeq @@ -331,7 +321,7 @@ struct SOME widx => doGates 0 (widx, 0) | 
NONE => 0) in - {result = Dense output, numGateApps = numGateApps} + {result = HS.Dense output, numGateApps = numGateApps} end @@ -343,9 +333,9 @@ struct val lookup = case state of - Sparse sst => (fn bidx => Option.getOpt (SST.lookup sst bidx, C.zero)) - | Dense ds => DS.lookupDirect ds - | DenseKnownNonZeroSize (ds, _) => DS.lookupDirect ds + HS.Sparse sst => (fn bidx => Option.getOpt (SST.lookup sst bidx, C.zero)) + | HS.Dense ds => DS.lookupDirect ds + | HS.DenseKnownNonZeroSize (ds, _) => DS.lookupDirect ds fun doGates (bidx, gatenum) = if gatenum < 0 then @@ -375,7 +365,7 @@ struct DS.pull {numQubits = numQubits} (fn bidx => doGates (bidx, numGates - 1)) in - { result = DenseKnownNonZeroSize (result, nonZeroSize) + { result = HS.DenseKnownNonZeroSize (result, nonZeroSize) , numGateApps = totalCount } end @@ -385,9 +375,9 @@ struct let val nonZeroSize = case state of - Sparse sst => SST.nonZeroSize sst - | Dense ds => DS.nonZeroSize ds - | DenseKnownNonZeroSize (_, nz) => nz + HS.Sparse sst => SST.nonZeroSize sst + | HS.Dense ds => DS.nonZeroSize ds + | HS.DenseKnownNonZeroSize (_, nz) => nz val rate = Real.max (1.0, Real.fromInt nonZeroSize / Real.fromInt prevNonZeroSize) @@ -407,7 +397,7 @@ struct val (method, {result, numGateApps}) = if - expectedCost < riMult denseThreshold maxNumStates + denseThreshold >= 1.0 orelse expectedCost < riMult denseThreshold maxNumStates then ("push sparse", expandSparse args) diff --git a/feynsum-sml/src/common/state/HYBRID_STATE.sml b/feynsum-sml/src/common/state/HYBRID_STATE.sml new file mode 100644 index 0000000..9d1f3f9 --- /dev/null +++ b/feynsum-sml/src/common/state/HYBRID_STATE.sml @@ -0,0 +1,19 @@ +signature HYBRID_STATE = +sig + structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + sharing B = SST.B = DS.B + sharing C = SST.C = DS.C + + (*type t + type state = t*) + datatype state = + Sparse of SST.t + | Dense of DS.t + | DenseKnownNonZeroSize of DS.t * int + + 
type t = state + +end diff --git a/feynsum-sml/src/common/state/HybridState.sml b/feynsum-sml/src/common/state/HybridState.sml new file mode 100644 index 0000000..17cdcc2 --- /dev/null +++ b/feynsum-sml/src/common/state/HybridState.sml @@ -0,0 +1,20 @@ +functor HybridState + (structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + sharing B = SST.B = DS.B + sharing C = SST.C = DS.C): HYBRID_STATE = +struct + structure B = B + structure C = C + structure SST = SST + structure DS = DS + + datatype state = + Sparse of SST.t + | Dense of DS.t + | DenseKnownNonZeroSize of DS.t * int + + type t = state +end diff --git a/feynsum-sml/src/common/SPARSE_STATE_TABLE.sml b/feynsum-sml/src/common/state/SPARSE_STATE_TABLE.sml similarity index 100% rename from feynsum-sml/src/common/SPARSE_STATE_TABLE.sml rename to feynsum-sml/src/common/state/SPARSE_STATE_TABLE.sml diff --git a/feynsum-sml/src/common/SparseState.sml b/feynsum-sml/src/common/state/SparseState.sml similarity index 100% rename from feynsum-sml/src/common/SparseState.sml rename to feynsum-sml/src/common/state/SparseState.sml diff --git a/feynsum-sml/src/common/SparseStateTable.sml b/feynsum-sml/src/common/state/SparseStateTable.sml similarity index 100% rename from feynsum-sml/src/common/SparseStateTable.sml rename to feynsum-sml/src/common/state/SparseStateTable.sml diff --git a/feynsum-sml/src/common/SparseStateTableLockedSlots.sml b/feynsum-sml/src/common/state/SparseStateTableLockedSlots.sml similarity index 100% rename from feynsum-sml/src/common/SparseStateTableLockedSlots.sml rename to feynsum-sml/src/common/state/SparseStateTableLockedSlots.sml diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index 9a0010b..ece05ec 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -29,63 +29,6 @@ val _ = print ("scheduler " ^ schedulerName ^ "\n") val inputName = CLA.parseString "input" "" val _ = print ("input " ^ inputName ^ "\n") 
-(* ======================================================================== - * gate scheduling - *) - - -local - val disableFusion = CLA.parseFlag "scheduler-disable-fusion" - val maxBranchingStride = CLA.parseInt "scheduler-max-branching-stride" 2 -in - structure GNB = - GateSchedulerGreedyNonBranching - (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion) - - structure GFQ = - GateSchedulerGreedyFinishQubit - (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion) - - fun print_sched_info () = - let - val _ = print - ("-------------------------------------\n\ - \--- scheduler-specific args\n\ - \-------------------------------------\n") - val _ = print - ("scheduler-max-branching-stride " ^ Int.toString maxBranchingStride - ^ "\n") - val _ = print - ("scheduler-disable-fusion? " ^ (if disableFusion then "yes" else "no") - ^ "\n") - val _ = print ("-------------------------------------\n") - in - () - end -end - - -val gateScheduler = - case schedulerName of - "naive" => GateSchedulerNaive.scheduler - - | "greedy-branching" => GateSchedulerGreedyBranching.scheduler - | "gb" => GateSchedulerGreedyBranching.scheduler - - | "greedy-nonbranching" => (print_sched_info (); GNB.scheduler) - | "gnb" => (print_sched_info (); GNB.scheduler) - - | "greedy-finish-qubit" => (print_sched_info (); GFQ.scheduler) - | "gfq" => (print_sched_info (); GFQ.scheduler) - - | _ => - Util.die - ("unknown scheduler: " ^ schedulerName - ^ - "; valid options are: naive, greedy-branching (gb), greedy-nonbranching (gnb), greedy-finish-qubit (gfq)") - (* ========================================================================= * parse input *) @@ -95,37 +38,44 @@ val _ = print \--- input-specific specs\n\ \-------------------------------\n") -val circuit = +fun parseQasm () = + let + fun handleLexOrParseError exn = + let + val e = + case exn of + SMLQasmError.Error e => e + | other => raise other + in + TerminalColorString.print + 
(SMLQasmError.show + {highlighter = SOME SMLQasmSyntaxHighlighter.fuzzyHighlight} e); + OS.Process.exit OS.Process.failure + end + + val ast = SMLQasmParser.parseFromFile inputName + handle exn => handleLexOrParseError exn + + val simpleCirc = SMLQasmSimpleCircuit.fromAst ast + in + DataFlowGraph.fromQasm (Circuit.fromSMLQasmSimpleCircuit simpleCirc) + end + +val dfg = case inputName of "" => Util.die ("missing: -input FILE.qasm") - | _ => - let - fun handleLexOrParseError exn = - let - val e = - case exn of - SMLQasmError.Error e => e - | other => raise other - in - TerminalColorString.print - (SMLQasmError.show - {highlighter = SOME SMLQasmSyntaxHighlighter.fuzzyHighlight} e); - OS.Process.exit OS.Process.failure - end - - val ast = SMLQasmParser.parseFromFile inputName - handle exn => handleLexOrParseError exn - - val simpleCirc = SMLQasmSimpleCircuit.fromAst ast - in - Circuit.fromSMLQasmSimpleCircuit simpleCirc - end + if String.isSuffix ".qasm" inputName then + parseQasm () + else if String.isSuffix ".json" inputName then + DataFlowGraph.fromJSONFile inputName + else + raise Fail "Unknown file suffix, use .qasm or .json" val _ = print ("-------------------------------\n") -val _ = print ("gates " ^ Int.toString (Circuit.numGates circuit) ^ "\n") -val _ = print ("qubits " ^ Int.toString (Circuit.numQubits circuit) ^ "\n") +val _ = print ("gates " ^ Int.toString (Seq.length (#gates dfg)) ^ "\n") +val _ = print ("qubits " ^ Int.toString (#numQubits dfg) ^ "\n") val showCircuit = CLA.parseFlag "show-circuit" val _ = print ("show-circuit? 
" ^ (if showCircuit then "yes" else "no") ^ "\n") @@ -135,9 +85,49 @@ val _ = else print ("=========================================================\n" - ^ Circuit.toString circuit + ^ DataFlowGraph.toString dfg ^ "=========================================================\n") +(* ======================================================================== + * gate scheduling + *) + + +val disableFusion = CLA.parseFlag "scheduler-disable-fusion" +val maxBranchingStride = CLA.parseInt "scheduler-max-branching-stride" 2 + +structure DGNB = + GreedyNonBranchingScheduler + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + +structure DGFQ = + FinishQubitScheduler + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + +fun print_sched_info () = + let + val _ = print + ("-------------------------------------\n\ + \--- scheduler-specific args\n\ + \-------------------------------------\n") + val _ = print + ("scheduler-max-branching-stride " ^ Int.toString maxBranchingStride + ^ "\n") + val _ = print + ("scheduler-disable-fusion? 
" ^ (if disableFusion then "yes" else "no") + ^ "\n") + val _ = print ("-------------------------------------\n") + in + () + end + +type gate_idx = int +type Schedule = gate_idx Seq.t + +val maxBranchingStride' = if disableFusion then 1 else maxBranchingStride + (* ======================================================================== * mains: 32-bit and 64-bit *) @@ -146,9 +136,11 @@ structure M_64_32 = MkMain (structure B = BasisIdx64 structure C = Complex32 + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion val blockSize = blockSize val maxload = maxload - val gateScheduler = gateScheduler + val gateScheduler = schedulerName val doMeasureZeros = doMeasureZeros val denseThreshold = denseThreshold val pullThreshold = pullThreshold) @@ -157,9 +149,11 @@ structure M_64_64 = MkMain (structure B = BasisIdx64 structure C = Complex64 + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion val blockSize = blockSize val maxload = maxload - val gateScheduler = gateScheduler + val gateScheduler = schedulerName val doMeasureZeros = doMeasureZeros val denseThreshold = denseThreshold val pullThreshold = pullThreshold) @@ -168,9 +162,11 @@ structure M_U_32 = MkMain (structure B = BasisIdxUnlimited structure C = Complex32 + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion val blockSize = blockSize val maxload = maxload - val gateScheduler = gateScheduler + val gateScheduler = schedulerName val doMeasureZeros = doMeasureZeros val denseThreshold = denseThreshold val pullThreshold = pullThreshold) @@ -179,14 +175,16 @@ structure M_U_64 = MkMain (structure B = BasisIdxUnlimited structure C = Complex64 + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion val blockSize = blockSize val maxload = maxload - val gateScheduler = gateScheduler + val gateScheduler = schedulerName val doMeasureZeros = doMeasureZeros val denseThreshold = denseThreshold val pullThreshold = 
pullThreshold) -val basisIdx64Okay = Circuit.numQubits circuit <= 62 +val basisIdx64Okay = #numQubits dfg <= 62 val main = case precision of @@ -199,4 +197,4 @@ val main = (* ======================================================================== *) -val _ = main (inputName, circuit) +val _ = main (inputName, dfg)